Skip to content

Commit

Permalink
Add timeout options
Browse files Browse the repository at this point in the history
  • Loading branch information
composerinteralia committed Apr 22, 2024
1 parent 0dcc7d0 commit d81b4e4
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 166 deletions.
20 changes: 20 additions & 0 deletions lib/nocturne.rb
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,26 @@ def closed?
@conn.closed?
end

def read_timeout=(timeout)
raise ConnectionClosed if @conn.closed?
@options[:read_timeout] = timeout
end

def read_timeout
raise ConnectionClosed if @conn.closed?
@options[:read_timeout]
end

def write_timeout=(timeout)
raise ConnectionClosed if @conn.closed?
@options[:write_timeout] = timeout
end

def write_timeout
raise ConnectionClosed if @conn.closed?
@options[:write_timeout]
end

private

def connect
Expand Down
3 changes: 3 additions & 0 deletions lib/nocturne/error.rb
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,7 @@ class ConnectionClosed < ConnectionError

class QueryError < Error
end

class TimeoutError < Error
end
end
5 changes: 2 additions & 3 deletions lib/nocturne/protocol.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@ def self.error(packet, klass)
end

def self.read_error(packet)
packet.int8
packet.skip(1) # Error
code = packet.int16
packet.strn(1)
packet.strn(5)
packet.skip(6) # SQL state
message = packet.eof_str
[code, message]
end
Expand Down
7 changes: 7 additions & 0 deletions lib/nocturne/protocol/handshake.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@ def engage
private

def server_handshake
original_read_timeout = @options[:read_timeout]
@options[:read_timeout] = @options[:connect_timeout] || @options[:write_timeout]

@conn.read_packet do |handshake|
raise Protocol.error(handshake, ConnectionError) if handshake.err?

_protocol_version = handshake.int8
@server_version = handshake.nulstr
_thread_id = handshake.int32
Expand All @@ -54,6 +59,8 @@ def server_handshake
@auth_plugin_data = auth_plugin_data + handshake.strn([13, auth_plugin_data_len - 8].max)
@auth_plugin_name = handshake.nulstr
end
ensure
@options[:read_timeout] = original_read_timeout
end

def ssl_request
Expand Down
2 changes: 2 additions & 0 deletions lib/nocturne/protocol/query.rb
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def read_rows(column_count)
more_rows = true
while more_rows
@conn.read_packet do |row|
raise Protocol.error(row, QueryError) if row.err?

if row.eof?
row.skip(1)
@conn.update_status(
Expand Down
35 changes: 19 additions & 16 deletions lib/nocturne/socket.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# frozen_string_literal:true

# standard:disable Lint/MissingCopEnableDirective
# standard:disable Style/YodaCondition

require "socket"
require "openssl"

Expand All @@ -17,8 +20,8 @@ def recv(buffer)
loop do
result = @sock.read_nonblock(MAX_BYTES, buffer, exception: false)

if :wait_readable == result # standard:disable Style/YodaCondition
IO.select(@select_sock)
if :wait_readable == result
IO.select(@select_sock, nil, nil, @options[:read_timeout]) || raise(TimeoutError)
else
return result
end
Expand All @@ -29,8 +32,8 @@ def sendmsg(data)
loop do
result = @sock.write_nonblock(data, exception: false)

if :wait_writable == result # standard:disable Style/YodaCondition
IO.select(nil, @select_sock)
if :wait_writable == result
IO.select(nil, @select_sock, nil, @options[:write_timeout]) || raise(TimeoutError)
else
return result
end
Expand All @@ -50,7 +53,8 @@ def ssl_sock
def connect(options)
if options[:host]
sock = ::Socket.new(::Socket::AF_INET, ::Socket::SOCK_STREAM)
sock.connect ::Socket.pack_sockaddr_in(options[:port] || 3306, options[:host] || "localhost")
addr = ::Socket.pack_sockaddr_in(options[:port] || 3306, options[:host] || "localhost")
sock.connect_nonblock(addr, exception: false)
else
sock = ::Socket.unix(options[:socket] || "/tmp/mysql.sock")
end
Expand All @@ -64,18 +68,17 @@ def initialize(sock, options)
@sock = OpenSSL::SSL::SSLSocket.new(sock, ssl_context(options))
@sock.connect
@select_sock = [@sock]
@options = options
end

MAX_BYTES = 32768

def recv(buffer)
loop do
result = @sock.read_nonblock(MAX_BYTES, buffer, exception: false)
result = @sock.read_nonblock(Nocturne::Socket::MAX_BYTES, buffer, exception: false)

if :wait_readable == result # standard:disable Style/YodaCondition
IO.select(@select_sock)
elsif :wait_writable == result # standard:disable Style/YodaCondition
IO.select(nil, @select_sock)
if :wait_readable == result
IO.select(@select_sock, nil, nil, @options[:read_timeout]) || raise(TimeoutError)
elsif :wait_writable == result
IO.select(nil, @select_sock, nil, @options[:write_timeout]) || raise(TimeoutError)
else
return result
end
Expand All @@ -86,10 +89,10 @@ def sendmsg(data)
loop do
result = @sock.write_nonblock(data, exception: false)

if :wait_readable == result # standard:disable Style/YodaCondition
IO.select(@select_sock)
elsif :wait_writable == result # standard:disable Style/YodaCondition
IO.select(nil, @select_sock)
if :wait_readable == result
IO.select(@select_sock, nil, nil, @options[:read_timeout]) || raise(TimeoutError)
elsif :wait_writable == result
IO.select(nil, @select_sock, nil, @options[:write_timeout]) || raise(TimeoutError)
else
return result
end
Expand Down
Loading

0 comments on commit d81b4e4

Please sign in to comment.