Skip to content

Commit

Permalink
chore: split files
Browse files Browse the repository at this point in the history
  • Loading branch information
lchenut committed Jul 26, 2024
1 parent c8a9fd3 commit 9b5c58d
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 109 deletions.
99 changes: 99 additions & 0 deletions webrtc/sctp/sctp_connection.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

import chronos, chronicles
import usrsctp
import ./sctp_utils
import ../errors

logScope:
topics = "webrtc sctp_connection"

proc sctpStrerror(error: cint): cstring {.importc: "strerror", cdecl, header: "<string.h>".}

type
SctpState = enum
Connecting
Connected
Closed

SctpMessageParameters* = object
protocolId*: uint32
streamId*: uint16
endOfRecord*: bool
unordered*: bool

SctpMessage* = ref object
data*: seq[byte]
info: sctp_recvv_rn
params*: SctpMessageParameters

SctpConn* = ref object
conn*: DtlsConn
state: SctpState
connectEvent: AsyncEvent
acceptEvent: AsyncEvent
readLoop: Future[void]
udp: DatagramTransport
address: TransportAddress
sctpSocket: ptr socket
dataRecv: AsyncQueue[SctpMessage]
sentFuture: Future[void]

proc new(T: typedesc[SctpConn], conn: DtlsConn): T =
T(conn: conn,
state: Connecting,
connectEvent: AsyncEvent(),
acceptEvent: AsyncEvent(),
dataRecv: newAsyncQueue[SctpMessage]()
)

proc read*(self: SctpConn): Future[SctpMessage] {.async.} =
# Used by DataChannel, returns SctpMessage in order to get the stream
# and protocol ids
return await self.dataRecv.popFirst()

proc toFlags(params: SctpMessageParameters): uint16 =
if params.endOfRecord:
result = result or SCTP_EOR
if params.unordered:
result = result or SCTP_UNORDERED

proc write*(self: SctpConn, buf: seq[byte],
sendParams = default(SctpMessageParameters)) {.async.} =
# Used by DataChannel, writes buf on the Dtls connection.
trace "Write", buf

var cpy = buf
let sendvErr =
if sendParams == default(SctpMessageParameters):
# If writes is called by DataChannel, sendParams should never
# be the default value. This split is useful for testing.
self.usrsctpAwait:
self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0,
nil, 0, SCTP_SENDV_NOINFO.cuint, 0)
else:
var sendInfo = sctp_sndinfo(
snd_sid: sendParams.streamId,
snd_ppid: sendParams.protocolId.swapBytes(),
snd_flags: sendParams.toFlags)
self.usrsctpAwait:
self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0,
cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen,
SCTP_SENDV_SNDINFO.cuint, 0)
if sendvErr < 0:
raise newException(WebRtcError, $(sctpStrerror(sendvErr)))

proc write*(self: SctpConn, s: string) {.async.} =
await self.write(s.toBytes())

proc close*(self: SctpConn) {.async.} =
self.usrsctpAwait:
self.sctpSocket.usrsctp_close()
usrsctp_deregister_address(cast[pointer](self))
115 changes: 6 additions & 109 deletions webrtc/sctp.nim → webrtc/sctp/sctp_transport.nim
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import tables, bitops, posix, strutils, sequtils
import chronos, chronicles, stew/[ranges/ptr_arith, byteutils, endians2]
import usrsctp
import dtls/dtls
import ../dtls/dtls_transport
import ./sctp_connection
import binary_serialization

export chronicles
Expand All @@ -31,39 +32,9 @@ proc perror(error: cstring) {.importc, cdecl, header: "<errno.h>".}
proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "<stdio.h>", gcsafe.}

type
SctpError* = object of CatchableError

SctpState = enum
Connecting
Connected
Closed

SctpMessageParameters* = object
protocolId*: uint32
streamId*: uint16
endOfRecord*: bool
unordered*: bool

SctpMessage* = ref object
data*: seq[byte]
info: sctp_recvv_rn
params*: SctpMessageParameters

SctpConn* = ref object
conn*: DtlsConn
state: SctpState
connectEvent: AsyncEvent
acceptEvent: AsyncEvent
readLoop: Future[void]
sctp: Sctp
udp: DatagramTransport
address: TransportAddress
sctpSocket: ptr socket
dataRecv: AsyncQueue[SctpMessage]
sentFuture: Future[void]

Sctp* = ref object
dtls: Dtls
laddr*: TransportAddress
udp: DatagramTransport
connections: Table[TransportAddress, SctpConn]
gotConnection: AsyncEvent
Expand All @@ -72,7 +43,6 @@ type
sockServer: ptr socket
pendingConnections: seq[SctpConn]
pendingConnections2: Table[SockAddr, SctpConn]
sentAddress: TransportAddress
sentFuture: Future[void]

# These three objects are used for debugging/trace only
Expand Down Expand Up @@ -106,79 +76,6 @@ proc getSctpPacket(buffer: seq[byte]): SctpPacketStructure =
# padding; could use `size.inc(-size %% 4)` instead but it lacks clarity
size.inc(1)

# -- Asynchronous wrapper --

template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped =
# usrsctpAwait is template which set `sentFuture` to nil then calls (usually)
# an usrsctp function. If during the synchronous run of the usrsctp function
# `sendCallback` is called, then `sentFuture` is set and waited.
self.sentFuture = nil
when type(body) is void:
body
if self.sentFuture != nil: await self.sentFuture
else:
let res = body
if self.sentFuture != nil: await self.sentFuture
res

# -- SctpConn --

proc new(T: typedesc[SctpConn], conn: DtlsConn, sctp: Sctp): T =
T(conn: conn,
sctp: sctp,
state: Connecting,
connectEvent: AsyncEvent(),
acceptEvent: AsyncEvent(),
dataRecv: newAsyncQueue[SctpMessage]() # TODO add some limit for backpressure?
)

proc read*(self: SctpConn): Future[SctpMessage] {.async.} =
# Used by DataChannel, returns SctpMessage in order to get the stream
# and protocol ids
return await self.dataRecv.popFirst()

proc toFlags(params: SctpMessageParameters): uint16 =
if params.endOfRecord:
result = result or SCTP_EOR
if params.unordered:
result = result or SCTP_UNORDERED

proc write*(self: SctpConn, buf: seq[byte],
sendParams = default(SctpMessageParameters)) {.async.} =
# Used by DataChannel, writes buf on the Dtls connection.
trace "Write", buf
self.sctp.sentAddress = self.address

var cpy = buf
let sendvErr =
if sendParams == default(SctpMessageParameters):
# If writes is called by DataChannel, sendParams should never
# be the default value. This split is useful for testing.
self.usrsctpAwait:
self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0,
nil, 0, SCTP_SENDV_NOINFO.cuint, 0)
else:
var sendInfo = sctp_sndinfo(
snd_sid: sendParams.streamId,
# TODO: swapBytes => htonl?
snd_ppid: sendParams.protocolId.swapBytes(),
snd_flags: sendParams.toFlags)
self.usrsctpAwait:
self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0,
cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen,
SCTP_SENDV_SNDINFO.cuint, 0)
if sendvErr < 0:
# TODO: throw an exception
perror("usrsctp_sendv")

proc write*(self: SctpConn, s: string) {.async.} =
await self.write(s.toBytes())

proc close*(self: SctpConn) {.async.} =
self.usrsctpAwait:
self.sctpSocket.usrsctp_close()
usrsctp_deregister_address(cast[pointer](self))

# -- usrsctp receive data callbacks --

proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
Expand Down Expand Up @@ -313,12 +210,13 @@ proc stopServer*(self: Sctp) =
pc.sctpSocket.usrsctp_close()
self.sockServer.usrsctp_close()

proc init*(self: Sctp, dtls: Dtls, laddr: TransportAddress) =
proc new*(T: type Sctp, dtls: Dtls) =
self.gotConnection = newAsyncEvent()
self.timersHandler = timersHandler()
self.dtls = dtls

usrsctp_init_nothreads(laddr.port.uint16, sendCallback, printf)

usrsctp_init_nothreads(dtls.laddr.port.uint16, sendCallback, printf)
discard usrsctp_sysctl_set_sctp_debug_on(SCTP_DEBUG_NONE)
discard usrsctp_sysctl_set_sctp_ecn_enable(1)
usrsctp_register_address(cast[pointer](self))
Expand Down Expand Up @@ -392,7 +290,6 @@ proc connect*(self: Sctp,
sconn.sconn_family = AF_CONN
sconn.sconn_port = htons(sctpPort)
sconn.sconn_addr = cast[pointer](conn)
self.sentAddress = address
usrsctp_register_address(cast[pointer](conn))
conn.readLoop = conn.readLoopProc()
let connErr = self.usrsctpAwait:
Expand Down
22 changes: 22 additions & 0 deletions webrtc/sctp/sctp_utils.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.

template usrsctpAwait*(self: untyped, body: untyped): untyped =
# usrsctpAwait is template which set `sentFuture` to nil then calls (usually)
# an usrsctp function. If during the synchronous run of the usrsctp function
# `sendCallback` is called, then `sentFuture` is set and waited.
# self should be Sctp or SctpConn
self.sentFuture = nil
when type(body) is void:
(body)
if self.sentFuture != nil: await self.sentFuture
else:
let res = (body)
if self.sentFuture != nil: await self.sentFuture
res

0 comments on commit 9b5c58d

Please sign in to comment.