Skip to content

Commit

Permalink
Add abstraction Payload to use in Publish
Browse files Browse the repository at this point in the history
  • Loading branch information
spuun committed Oct 23, 2024
1 parent eb66e5b commit a6edb07
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 13 deletions.
2 changes: 1 addition & 1 deletion spec/io_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ describe MQTT::Protocol::IO do
mio = IO::Memory.new
io = MQTT::Protocol::IO.new(mio)

io.write_bytes_raw bytes
io.write bytes
mio.rewind

res = Bytes.new(3)
Expand Down
18 changes: 9 additions & 9 deletions spec/packets_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,13 @@ describe MQTT::Protocol::Packet do
io = MQTT::Protocol::IO.new(mio)

topic = "a/b/c"
payload = "foobar and barfoo".to_slice
payload = MQTT::Protocol::Payload.new("foobar and barfoo".to_slice)
remaining_length = topic.bytesize + payload.size + 2 # 2 = sizeof topic len

io.write_byte 0b00110000u8
io.write_remaining_length remaining_length
io.write_string topic
io.write_bytes_raw payload
io.write_bytes payload

mio.rewind

Expand All @@ -282,7 +282,7 @@ describe MQTT::Protocol::Packet do
io.write_byte 0b00111000u8
io.write_remaining_length remaining_length
io.write_string topic
io.write_bytes_raw payload
io.write payload

mio.rewind

Expand All @@ -302,7 +302,7 @@ describe MQTT::Protocol::Packet do
io.write_byte 0b00111000u8
io.write_remaining_length remaining_length
io.write_string topic
io.write_bytes_raw payload
io.write payload

mio.rewind

Expand All @@ -318,7 +318,7 @@ describe MQTT::Protocol::Packet do
io = MQTT::Protocol::IO.new(mio)

topic = "a/b/c"
payload = "foobar and barfoo".to_slice
payload = MQTT::Protocol::Payload.new("foobar and barfoo".to_slice)
packet_id = 100u16
publish = MQTT::Protocol::Publish.new(topic, payload, packet_id, false, 1, false)
publish.to_io(io)
Expand All @@ -335,7 +335,7 @@ describe MQTT::Protocol::Packet do

it "raises error if dup is set for QoS 0 messages" do
topic = "a/b/c"
payload = "foobar and barfoo".to_slice
payload = MQTT::Protocol::Payload.new("foobar and barfoo".to_slice)
packet_id = 100u16
expect_raises(ArgumentError) do
MQTT::Protocol::Publish.new(topic, payload, packet_id, true, 0, false)
Expand All @@ -347,7 +347,7 @@ describe MQTT::Protocol::Packet do
io = MQTT::Protocol::IO.new(mio)

topic = "a/b/c"
payload = "foobar and barfoo".to_slice
payload = MQTT::Protocol::Payload.new("foobar and barfoo".to_slice)
packet_id = 100u16
publish = MQTT::Protocol::Publish.new(topic, payload, packet_id, false, 0, false)
publish.to_io(io)
Expand All @@ -365,7 +365,7 @@ describe MQTT::Protocol::Packet do
describe "#initialize" do
it "raises an error if QoS is 3" do
topic = "a/b/c"
payload = "foobar and barfoo".to_slice
payload = MQTT::Protocol::Payload.new("foobar and barfoo".to_slice)
packet_id = 100u16
expect_raises(ArgumentError) do
MQTT::Protocol::Publish.new(topic, payload, packet_id, false, 3, false)
Expand All @@ -375,7 +375,7 @@ describe MQTT::Protocol::Packet do
describe "with wildcard in topic" do
it "should raise ArguementError" do
topic = "a/#"
payload = "foobar and barfoo".to_slice
payload = MQTT::Protocol::Payload.new("foobar and barfoo".to_slice)
packet_id = 100u16

expect_raises(ArgumentError) do
Expand Down
77 changes: 77 additions & 0 deletions spec/payload_spec.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
require "./spec_helper"

describe MQTT::Protocol::Payload do
it ".new(Bytes) returns a BytesPayload" do
obj = MQTT::Protocol::Payload.new("foo".to_slice)
obj.should be_a(MQTT::Protocol::BytesPayload)
end

it ".new(IO) returns a IOPayload" do
io = IO::Memory.new
io.write "foo".to_slice
obj = MQTT::Protocol::Payload.new(io, 3)
obj.should be_a(MQTT::Protocol::IOPayload)
end

describe "#==" do
it "should return true for two BytePayload with same bytes" do
one = MQTT::Protocol::BytesPayload.new("foo".to_slice)
two = MQTT::Protocol::BytesPayload.new("foo".to_slice)

(one == two).should be_true
end

it "should return false for two BytePayload with different bytes" do
one = MQTT::Protocol::BytesPayload.new("foo".to_slice)
two = MQTT::Protocol::BytesPayload.new("bar".to_slice)

(one == two).should be_false
end

it "should return true for two IOPayload with same content" do
io_one = IO::Memory.new("foo".to_slice)
io_two = IO::Memory.new("foo".to_slice)

io_one.rewind
io_two.rewind

one = MQTT::Protocol::IOPayload.new(io_one, 3)
two = MQTT::Protocol::IOPayload.new(io_two, 3)

(one == two).should be_true
end

it "should return false for two IOPayload with different content" do
io_one = IO::Memory.new("foo".to_slice)
io_two = IO::Memory.new("bar".to_slice)

io_one.rewind
io_two.rewind

one = MQTT::Protocol::IOPayload.new(io_one, 3)
two = MQTT::Protocol::IOPayload.new(io_two, 3)

(one == two).should be_false
end

it "should return true for one BytesPayload and one IOPayload with same content" do
io_two = IO::Memory.new("foo".to_slice)
io_two.rewind

one = MQTT::Protocol::BytesPayload.new("foo".to_slice)
two = MQTT::Protocol::IOPayload.new(io_two, 3)

(one == two).should be_true
end

it "should return false for one BytesPayload and one IOPayload with different content" do
io_two = IO::Memory.new("bar".to_slice)
io_two.rewind

one = MQTT::Protocol::BytesPayload.new("foo".to_slice)
two = MQTT::Protocol::IOPayload.new(io_two, 3)

(one == two).should be_false
end
end
end
1 change: 1 addition & 0 deletions src/mqtt/protocol/io.cr
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ module MQTT
@io.write bytes
end

@[Deprecated("Use write_bytes instead")]
def write_bytes_raw(bytes : Bytes)
@io.write bytes
end
Expand Down
20 changes: 17 additions & 3 deletions src/mqtt/protocol/packets/publish.cr
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
require "../payload"
require "./packets"

module MQTT
Expand All @@ -8,7 +9,7 @@ module MQTT
getter topic, payload, qos, packet_id, remaining_length
getter? dup, retain

def initialize(@topic : String, @payload : Bytes, @packet_id : UInt16?, @dup : Bool, @qos : UInt8, @retain : Bool)
def initialize(@topic : String, @payload : Payload, @packet_id : UInt16?, @dup : Bool, @qos : UInt8, @retain : Bool)
raise ArgumentError.new("QoS must be 0, 1 or 2") if @qos > 2
raise ArgumentError.new("Topic cannot contain wildcard") if @topic.matches?(/[#+]/)
raise ArgumentError.new("Topic must be between atleast 1 char long") if @topic.size < 1
Expand All @@ -19,6 +20,19 @@ module MQTT
@remaining_length += 2 if qos.positive? # packet_id
end

@[Deprecated("Use Payload instead of Bytes for @payload")]
def initialize(@topic : String, payload : Bytes, @packet_id : UInt16?, @dup : Bool, @qos : UInt8, @retain : Bool)
raise ArgumentError.new("QoS must be 0, 1 or 2") if @qos > 2
raise ArgumentError.new("Topic cannot contain wildcard") if @topic.matches?(/[#+]/)
raise ArgumentError.new("Topic must be between atleast 1 char long") if @topic.size < 1
raise ArgumentError.new("Topic cannot be larger than 65535 bytes") if @topic.bytesize > 65535
raise ArgumentError.new("DUP must be 0 for QoS 0 messages") if dup? && qos.zero?
@payload = BytesPayload.new(payload)
@remaining_length = 0
@remaining_length += (2 + topic.bytesize) + payload.bytesize
@remaining_length += 2 if qos.positive? # packet_id
end

def self.from_io(io : MQTT::Protocol::IO, flags : Flags, remaining_length : UInt32)
dup = flags.bit(3) > 0
retain = flags.bit(0) > 0
Expand All @@ -32,7 +46,7 @@ module MQTT
else
decode_assert dup == false, "DUP must be 0 for QoS 0 messages"
end
payload = io.read_bytes(remaining_length.to_u16)
payload = IOPayload.new(io, remaining_length.to_i32)
self.new(topic, payload, packet_id, dup, qos, retain)
rescue ex : ArgumentError
raise MQTT::Protocol::Error::PacketDecode.new(ex.message)
Expand All @@ -47,7 +61,7 @@ module MQTT
io.write_remaining_length remaining_length
io.write_string topic
io.write_int packet_id.not_nil! if qos.positive?
io.write_bytes_raw(payload)
io.write_bytes payload
end
end
end
Expand Down
82 changes: 82 additions & 0 deletions src/mqtt/protocol/payload.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
require "./io"

module MQTT
module Protocol
abstract struct Payload
def self.new(bytes : Bytes)
BytesPayload.new(bytes)
end

def self.new(io : ::IO, bytesize : Int32)
IOPayload.new(MQTT::Protocol::IO.new(io), bytesize)
end

def self.new(io : MQTT::Protocol::IO, bytesize : Int32)
IOPayload.new(io, bytesize)
end

def size
bytesize
end

abstract def bytesize : Int32
abstract def to_slice : Bytes
abstract def to_io(io, format : ::IO::ByteFormat = IO::ByteFormat::SystemEndian)

def ==(other)
return false unless other.is_a?(Payload)
to_slice == other.to_slice
end
end

struct BytesPayload < Payload
def initialize(@bytes : Bytes)
end

def bytesize : Int32
@bytes.bytesize
end

def to_slice : Bytes
@bytes
end

def to_io(io, format : ::IO::ByteFormat = IO::ByteFormat::SystemEndian)
io.write @bytes
end
end

struct IOPayload < Payload
getter bytesize : Int32

@data : Bytes? = nil

def initialize(@io : MQTT::Protocol::IO, @bytesize : Int32)
end

def initialize(io : ::IO, @bytesize : Int32)
@io = MQTT::Protocol::IO.new(io)
end

def to_slice : Bytes
if peeked = @io.peek.try &.[0, bytesize]?
return peeked
end
return @data || begin
data = Bytes.new(bytesize)
@io.read(data)
data
end

Check notice on line 69 in src/mqtt/protocol/payload.cr

View workflow job for this annotation

GitHub Actions / Ameba

Style/RedundantReturn

Redundant `return` detected
Raw output
> return @data || begin
  ^
end

def to_io(io, format : ::IO::ByteFormat = IO::ByteFormat::SystemEndian)
if data = @data
io.write data
else
copied = ::IO.copy(@io, io, bytesize)
raise "Failed to copy payload" if copied != bytesize
end
end
end
end
end

0 comments on commit a6edb07

Please sign in to comment.