Skip to content

Commit

Permalink
tcp: MPTCP support
Browse files Browse the repository at this point in the history
For Linux define IPPROTO_MPTCP if not available to 266
For other platforms - check if constant is defined
Fall back to TCP if not available and for Unix sockets
  • Loading branch information
shramov committed Dec 13, 2023
1 parent ed3d323 commit 4332349
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
38 changes: 38 additions & 0 deletions python/test/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tll.channel as C
from tll.error import TLLError
from tll.test_util import Accum, ports
from tll.processor import Loop

import os
import pytest
Expand Down Expand Up @@ -326,3 +327,40 @@ def test_client_scheme():
c.open()

assert [m.name for m in s.scheme.messages] == ['Test']

@pytest.mark.skipif(not hasattr(socket, 'IPPROTO_MPTCP'), reason="MPTCP not available")
@pytest.mark.parametrize("client", ["yes", "no"])
@pytest.mark.parametrize("server", ["yes", "no"])
@pytest.mark.parametrize("host", [f"127.0.0.1:{ports.TCP4}", f"::1:{ports.TCP6}", "./tcp.sock"])
def test_mptcp(host, client, server):
s = Accum(f'tcp://{host};mode=server;dump=frame', mptcp=server)
c = Accum(f'tcp://{host};mode=client;dump=frame', mptcp=client)

loop = Loop()

loop.add(s)
loop.add(c)

s.open()
c.open()

try:
for _ in range(10):
if s.result:
break
loop.step(0.001)
addr = s.result[-1].addr
s.result = []

c.post(b'xxx', seq=10)
s.post(b'zzz', seq=20, addr=addr)

for _ in range(10):
if s.result and c.result:
break
loop.step(0.001)
assert [(m.data.tobytes(), m.seq) for m in s.result] == [(b'xxx', 10)]
assert [(m.data.tobytes(), m.seq) for m in c.result] == [(b'zzz', 20)]
finally:
s.close()
c.close()
1 change: 1 addition & 0 deletions src/tll/channel/tcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ struct tcp_settings_t {
bool timestamping = false;
bool keepalive = true;
bool nodelay = false;
bool mptcp = false;
};

struct tcp_connect_t {
Expand Down
26 changes: 24 additions & 2 deletions src/tll/channel/tcp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
#include <linux/sockios.h>
#include <linux/net_tstamp.h>
#include <sys/ioctl.h>

#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
#endif

#if defined(__APPLE__) || defined(__FreeBSD__)
Expand Down Expand Up @@ -56,6 +60,23 @@ size_t _fill_iovec(size_t full, struct iovec * iov, const Arg &arg, const Args &
return _fill_iovec(full + tll::memoryview_api<Arg>::size(arg), iov + 1, std::forward<const Args &>(args)...);
}

inline int check_mptcp_proto(tll::Logger &log, const tcp_settings_t &settings, int af)
{
if (!settings.mptcp)
return 0;
if (af == AF_UNIX) {
log.info("MPTCP not supported for Unix sockets, fall back to TCP");
return 0;
}

#ifdef IPPROTO_MPTCP
return IPPROTO_MPTCP;
#else
log.warning("MPTCP not supported on this platform, fall back to TCP");
return 0;
#endif
}

} // namespace _

template <typename T>
Expand Down Expand Up @@ -359,6 +380,7 @@ int TcpClient<T, S>::_init(const tll::Channel::Url &url, tll::Channel *master)
_settings.sndbuf = reader.getT("sndbuf", util::Size { 0 });
_settings.rcvbuf = reader.getT("rcvbuf", util::Size { 0 });
_settings.buffer_size = reader.getT("buffer-size", util::Size { 64 * 1024 });
_settings.mptcp = reader.getT("mptcp", false);
if (!reader)
return this->_log.fail(EINVAL, "Invalid url: {}", reader.error());

Expand Down Expand Up @@ -406,7 +428,7 @@ int TcpClient<T, S>::_open(const ConstConfig &url)
std::swap(_addr_list, *addr);
_addr = _addr_list.begin();

auto fd = socket((*_addr)->sa_family, SOCK_STREAM, 0);
auto fd = socket((*_addr)->sa_family, SOCK_STREAM, _::check_mptcp_proto(this->_log, _settings, (*_addr)->sa_family));
if (fd == -1)
return this->_log.fail(errno, "Failed to create socket: {}", strerror(errno));
this->_update_fd(fd);
Expand Down Expand Up @@ -619,7 +641,7 @@ int TcpServer<T, C>::_bind(const tll::network::sockaddr_any &addr)
static constexpr int sflags = SOCK_STREAM;
#endif

tll::network::scoped_socket fd(socket(addr->sa_family, sflags, 0));
tll::network::scoped_socket fd(socket(addr->sa_family, sflags, _::check_mptcp_proto(this->_log, _settings, addr->sa_family)));
if (fd == -1)
return this->_log.fail(errno, "Failed to create socket: {}", strerror(errno));

Expand Down

0 comments on commit 4332349

Please sign in to comment.