Skip to content

Commit

Permalink
chore: remove duplication mbedtls initialization code in accept/conne…
Browse files Browse the repository at this point in the history
…ct and un-expose mbedtls context
  • Loading branch information
lchenut committed Jul 31, 2024
1 parent ffa8a51 commit d003d20
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 108 deletions.
163 changes: 110 additions & 53 deletions webrtc/dtls/dtls_connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import chronos, chronicles
import ../errors, ../stun/[stun_connection]
import ./dtls_utils

import mbedtls/ssl
import mbedtls/ssl_cookie
Expand All @@ -29,17 +30,17 @@ logScope:
topics = "webrtc dtls_conn"

type
MbedTLSCtx* = object
ssl*: mbedtls_ssl_context
config*: mbedtls_ssl_config
cookie*: mbedtls_ssl_cookie_ctx
cache*: mbedtls_ssl_cache_context
timer*: mbedtls_timing_delay_context
pkey*: mbedtls_pk_context
srvcert*: mbedtls_x509_crt

ctr_drbg*: mbedtls_ctr_drbg_context
entropy*: mbedtls_entropy_context
MbedTLSCtx = object
ssl: mbedtls_ssl_context
config: mbedtls_ssl_config
cookie: mbedtls_ssl_cookie_ctx
cache: mbedtls_ssl_cache_context
timer: mbedtls_timing_delay_context
pkey: mbedtls_pk_context
srvcert: mbedtls_x509_crt

ctr_drbg: mbedtls_ctr_drbg_context
entropy: mbedtls_entropy_context

DtlsConn* = ref object
# DtlsConn is a Dtls connection receiving and sending data using
Expand All @@ -64,6 +65,46 @@ type
# Mbed-TLS contexts
ctx*: MbedTLSCtx

proc verify(ctx: pointer, pcert: ptr mbedtls_x509_crt,
state: cint, pflags: ptr uint32): cint {.cdecl.} =
# verify is the procedure called by mbedtls when receiving the remote
# certificate. It's usually used to verify the validity of the certificate.
# We use this procedure to store the remote certificate as it's mandatory
# to have it for the Prologue of the Noise protocol, aswell as the localCertificate.
var self = cast[DtlsConn](ctx)
let cert = pcert[]

self.remoteCert = newSeq[byte](cert.raw.len)
copyMem(addr self.remoteCert[0], cert.raw.p, cert.raw.len)
return 0

proc dtlsSend(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
# dtlsSend is the procedure called by mbedtls when data needs to be sent.
# As the StunConn's write proc is asynchronous and dtlsSend cannot be async,
# we store the future of this write and await it after the end of the
# function (see write or dtlsHanshake for example).
var self = cast[DtlsConn](ctx)
var toWrite = newSeq[byte](len)
if len > 0:
copyMem(addr toWrite[0], buf, len)
trace "dtls send", len
self.sendFuture = self.conn.write(toWrite)
result = len.cint

proc dtlsRecv(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
# dtlsRecv is the procedure called by mbedtls when data needs to be received.
# As we cannot asynchronously await for data to be received, we use a data received
# queue. If this queue is empty, we return `MBEDTLS_ERR_SSL_WANT_READ` for us to await
# when the mbedtls proc resumed (see read or dtlsHandshake for example)
let self = cast[DtlsConn](ctx)
if self.dataRecv.len() == 0:
return MBEDTLS_ERR_SSL_WANT_READ

copyMem(buf, addr self.dataRecv[0], self.dataRecv.len())
result = self.dataRecv.len().cint
self.dataRecv = @[]
trace "dtls receive", len, result

proc new*(T: type DtlsConn, conn: StunConn, laddr: TransportAddress): T =
## Initialize a Dtls Connection
##
Expand All @@ -73,6 +114,64 @@ proc new*(T: type DtlsConn, conn: StunConn, laddr: TransportAddress): T =
self.closeEvent = newAsyncEvent()
return self

proc dtlsConnInit(self: DtlsConn) =
mb_ssl_init(self.ctx.ssl)
mb_ssl_config_init(self.ctx.config)
mb_ssl_conf_rng(self.ctx.config, mbedtls_ctr_drbg_random, self.ctx.ctr_drbg)
mb_ssl_conf_read_timeout(self.ctx.config, 10000) # in milliseconds
mb_ssl_conf_ca_chain(self.ctx.config, self.ctx.srvcert.next, nil)
mb_ssl_set_timer_cb(self.ctx.ssl, self.ctx.timer)
mb_ssl_set_verify(self.ctx.ssl, verify, self)
mb_ssl_set_bio(self.ctx.ssl, cast[pointer](self), dtlsSend, dtlsRecv, nil)

proc acceptInit*(
self: DtlsConn,
ctr_drbg: mbedtls_ctr_drbg_context,
pkey: mbedtls_pk_context,
srvcert: mbedtls_x509_crt,
localCert: seq[byte]
) =
self.ctx.ctr_drbg = ctr_drbg
self.ctx.pkey = pkey
self.ctx.srvcert = srvcert
self.localCert = localCert

self.dtlsConnInit()
mb_ssl_cookie_init(self.ctx.cookie)
mb_ssl_cache_init(self.ctx.cache)
mb_ssl_config_defaults(
self.ctx.config,
MBEDTLS_SSL_IS_SERVER,
MBEDTLS_SSL_TRANSPORT_DATAGRAM,
MBEDTLS_SSL_PRESET_DEFAULT
)
mb_ssl_conf_own_cert(self.ctx.config, self.ctx.srvcert, self.ctx.pkey)
mb_ssl_cookie_setup(self.ctx.cookie, mbedtls_ctr_drbg_random, self.ctx.ctr_drbg)
mb_ssl_conf_dtls_cookies(self.ctx.config, addr self.ctx.cookie)
mb_ssl_setup(self.ctx.ssl, self.ctx.config)
mb_ssl_session_reset(self.ctx.ssl)
mb_ssl_conf_authmode(self.ctx.config, MBEDTLS_SSL_VERIFY_OPTIONAL)

proc connectInit*(
self: DtlsConn,
ctr_drbg: mbedtls_ctr_drbg_context
) =
self.ctx.ctr_drbg = ctr_drbg
self.ctx.pkey = self.ctx.ctr_drbg.generateKey()
self.ctx.srvcert = self.ctx.ctr_drbg.generateCertificate(self.ctx.pkey)
self.localCert = newSeq[byte](self.ctx.srvcert.raw.len)
copyMem(addr self.localCert[0], self.ctx.srvcert.raw.p, self.ctx.srvcert.raw.len)

self.dtlsConnInit()
mb_ssl_config_defaults(
self.ctx.config,
MBEDTLS_SSL_IS_CLIENT,
MBEDTLS_SSL_TRANSPORT_DATAGRAM,
MBEDTLS_SSL_PRESET_DEFAULT
)
mb_ssl_setup(self.ctx.ssl, self.ctx.config)
mb_ssl_conf_authmode(self.ctx.config, MBEDTLS_SSL_VERIFY_OPTIONAL)

proc join*(self: DtlsConn) {.async: (raises: [CancelledError]).} =
## Wait for the Dtls Connection to be closed
##
Expand Down Expand Up @@ -175,45 +274,3 @@ proc remoteCertificate*(conn: DtlsConn): seq[byte] =

proc localCertificate*(conn: DtlsConn): seq[byte] =
conn.localCert

# -- MbedTLS Callbacks --

proc verify*(ctx: pointer, pcert: ptr mbedtls_x509_crt,
state: cint, pflags: ptr uint32): cint {.cdecl.} =
# verify is the procedure called by mbedtls when receiving the remote
# certificate. It's usually used to verify the validity of the certificate.
# We use this procedure to store the remote certificate as it's mandatory
# to have it for the Prologue of the Noise protocol, aswell as the localCertificate.
var self = cast[DtlsConn](ctx)
let cert = pcert[]

self.remoteCert = newSeq[byte](cert.raw.len)
copyMem(addr self.remoteCert[0], cert.raw.p, cert.raw.len)
return 0

proc dtlsSend*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
# dtlsSend is the procedure called by mbedtls when data needs to be sent.
# As the StunConn's write proc is asynchronous and dtlsSend cannot be async,
# we store the future of this write and await it after the end of the
# function (see write or dtlsHanshake for example).
var self = cast[DtlsConn](ctx)
var toWrite = newSeq[byte](len)
if len > 0:
copyMem(addr toWrite[0], buf, len)
trace "dtls send", len
self.sendFuture = self.conn.write(toWrite)
result = len.cint

proc dtlsRecv*(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
# dtlsRecv is the procedure called by mbedtls when data needs to be received.
# As we cannot asynchronously await for data to be received, we use a data received
# queue. If this queue is empty, we return `MBEDTLS_ERR_SSL_WANT_READ` for us to await
# when the mbedtls proc resumed (see read or dtlsHandshake for example)
let self = cast[DtlsConn](ctx)
if self.dataRecv.len() == 0:
return MBEDTLS_ERR_SSL_WANT_READ

copyMem(buf, addr self.dataRecv[0], self.dataRecv.len())
result = self.dataRecv.len().cint
self.dataRecv = @[]
trace "dtls receive", len, result
58 changes: 3 additions & 55 deletions webrtc/dtls/dtls_transport.nim
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# This file may not be copied, modified, or distributed except according to
# those terms.

import times, deques, tables, sequtils
import deques, tables, sequtils
import chronos, chronicles
import ./[dtls_utils, dtls_connection], ../errors,
../stun/[stun_connection, stun_transport]
Expand Down Expand Up @@ -86,37 +86,8 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
## Accept a Dtls Connection
##
var res = DtlsConn.new(await self.transport.accept(), self.laddr)
res.acceptInit(self.ctr_drbg, self.serverPrivKey, self.serverCert, self.localCert)

mb_ssl_init(res.ctx.ssl)
mb_ssl_config_init(res.ctx.config)
mb_ssl_cookie_init(res.ctx.cookie)
mb_ssl_cache_init(res.ctx.cache)

res.ctx.ctr_drbg = self.ctr_drbg
res.ctx.entropy = self.entropy

res.ctx.pkey = self.serverPrivKey
res.ctx.srvcert = self.serverCert
res.localCert = self.localCert

mb_ssl_config_defaults(
res.ctx.config,
MBEDTLS_SSL_IS_SERVER,
MBEDTLS_SSL_TRANSPORT_DATAGRAM,
MBEDTLS_SSL_PRESET_DEFAULT
)
mb_ssl_conf_rng(res.ctx.config, mbedtls_ctr_drbg_random, res.ctx.ctr_drbg)
mb_ssl_conf_read_timeout(res.ctx.config, 10000) # in milliseconds
mb_ssl_conf_ca_chain(res.ctx.config, res.ctx.srvcert.next, nil)
mb_ssl_conf_own_cert(res.ctx.config, res.ctx.srvcert, res.ctx.pkey)
mb_ssl_cookie_setup(res.ctx.cookie, mbedtls_ctr_drbg_random, res.ctx.ctr_drbg)
mb_ssl_conf_dtls_cookies(res.ctx.config, addr res.ctx.cookie)
mb_ssl_set_timer_cb(res.ctx.ssl, res.ctx.timer)
mb_ssl_setup(res.ctx.ssl, res.ctx.config)
mb_ssl_session_reset(res.ctx.ssl)
mb_ssl_set_verify(res.ctx.ssl, verify, res)
mb_ssl_conf_authmode(res.ctx.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
mb_ssl_set_bio(res.ctx.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil)
while true:
try:
self.connections[res.raddr] = res
Expand All @@ -133,30 +104,7 @@ proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} =
## Connect to a remote address, creating a Dtls Connection
var res = DtlsConn.new(await self.transport.connect(raddr), self.laddr)

mb_ssl_init(res.ctx.ssl)
mb_ssl_config_init(res.ctx.config)

res.ctx.ctr_drbg = self.ctr_drbg
res.ctx.entropy = self.entropy

res.ctx.pkey = res.ctx.ctr_drbg.generateKey()
res.ctx.srvcert = res.ctx.ctr_drbg.generateCertificate(res.ctx.pkey)
res.localCert = newSeq[byte](res.ctx.srvcert.raw.len)
copyMem(addr res.localCert[0], res.ctx.srvcert.raw.p, res.ctx.srvcert.raw.len)

mb_ssl_config_defaults(res.ctx.config,
MBEDTLS_SSL_IS_CLIENT,
MBEDTLS_SSL_TRANSPORT_DATAGRAM,
MBEDTLS_SSL_PRESET_DEFAULT)
mb_ssl_conf_rng(res.ctx.config, mbedtls_ctr_drbg_random, res.ctx.ctr_drbg)
mb_ssl_conf_read_timeout(res.ctx.config, 10000) # in milliseconds
mb_ssl_conf_ca_chain(res.ctx.config, res.ctx.srvcert.next, nil)
mb_ssl_set_timer_cb(res.ctx.ssl, res.ctx.timer)
mb_ssl_setup(res.ctx.ssl, res.ctx.config)
mb_ssl_set_verify(res.ctx.ssl, verify, res)
mb_ssl_conf_authmode(res.ctx.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
mb_ssl_set_bio(res.ctx.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil)
res.connectInit(self.ctr_drbg)

try:
self.connections[raddr] = res
Expand Down

0 comments on commit d003d20

Please sign in to comment.