From a6edb0752796bee393d5289d811865043537910a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20B=C3=B6rjesson?= Date: Wed, 23 Oct 2024 14:03:39 +0200 Subject: [PATCH] Add abstraction Payload to use in Publish --- spec/io_spec.cr | 2 +- spec/packets_spec.cr | 18 +++--- spec/payload_spec.cr | 77 ++++++++++++++++++++++++++ src/mqtt/protocol/io.cr | 1 + src/mqtt/protocol/packets/publish.cr | 20 ++++++- src/mqtt/protocol/payload.cr | 82 ++++++++++++++++++++++++++++ 6 files changed, 187 insertions(+), 13 deletions(-) create mode 100644 spec/payload_spec.cr create mode 100644 src/mqtt/protocol/payload.cr diff --git a/spec/io_spec.cr b/spec/io_spec.cr index 55be705..041f851 100644 --- a/spec/io_spec.cr +++ b/spec/io_spec.cr @@ -84,7 +84,7 @@ describe MQTT::Protocol::IO do mio = IO::Memory.new io = MQTT::Protocol::IO.new(mio) - io.write_bytes_raw bytes + io.write bytes mio.rewind res = Bytes.new(3) diff --git a/spec/packets_spec.cr b/spec/packets_spec.cr index 16d0199..e1e9dbe 100644 --- a/spec/packets_spec.cr +++ b/spec/packets_spec.cr @@ -254,13 +254,13 @@ describe MQTT::Protocol::Packet do io = MQTT::Protocol::IO.new(mio) topic = "a/b/c" - payload = "foobar and barfoo".to_slice + payload = MQTT::Protocol::Payload.new("foobar and barfoo".to_slice) remaining_length = topic.bytesize + payload.size + 2 # 2 = sizeof topic len io.write_byte 0b00110000u8 io.write_remaining_length remaining_length io.write_string topic - io.write_bytes_raw payload + io.write_bytes payload mio.rewind @@ -282,7 +282,7 @@ describe MQTT::Protocol::Packet do io.write_byte 0b00111000u8 io.write_remaining_length remaining_length io.write_string topic - io.write_bytes_raw payload + io.write payload mio.rewind @@ -302,7 +302,7 @@ describe MQTT::Protocol::Packet do io.write_byte 0b00111000u8 io.write_remaining_length remaining_length io.write_string topic - io.write_bytes_raw payload + io.write payload mio.rewind @@ -318,7 +318,7 @@ describe MQTT::Protocol::Packet do io = MQTT::Protocol::IO.new(mio) topic = "a/b/c" - payload = "foobar and barfoo".to_slice + payload = MQTT::Protocol::Payload.new("foobar and barfoo".to_slice) packet_id = 100u16 publish = MQTT::Protocol::Publish.new(topic, payload, packet_id, false, 1, false) publish.to_io(io) @@ -335,7 +335,7 @@ describe MQTT::Protocol::Packet do it "raises error if dup is set for QoS 0 messages" do topic = "a/b/c" - payload = "foobar and barfoo".to_slice + payload = MQTT::Protocol::Payload.new("foobar and barfoo".to_slice) packet_id = 100u16 expect_raises(ArgumentError) do MQTT::Protocol::Publish.new(topic, payload, packet_id, true, 0, false) @@ -347,7 +347,7 @@ describe MQTT::Protocol::Packet do io = MQTT::Protocol::IO.new(mio) topic = "a/b/c" - payload = "foobar and barfoo".to_slice + payload = MQTT::Protocol::Payload.new("foobar and barfoo".to_slice) packet_id = 100u16 publish = MQTT::Protocol::Publish.new(topic, payload, packet_id, false, 0, false) publish.to_io(io) @@ -365,7 +365,7 @@ describe MQTT::Protocol::Packet do describe "#initialize" do it "raises an error if QoS is 3" do topic = "a/b/c" - payload = "foobar and barfoo".to_slice + payload = MQTT::Protocol::Payload.new("foobar and barfoo".to_slice) packet_id = 100u16 expect_raises(ArgumentError) do MQTT::Protocol::Publish.new(topic, payload, packet_id, false, 3, false) @@ -375,7 +375,7 @@ describe MQTT::Protocol::Packet do describe "with wildcard in topic" do it "should raise ArguementError" do topic = "a/#" - payload = "foobar and barfoo".to_slice + payload = MQTT::Protocol::Payload.new("foobar and barfoo".to_slice) packet_id = 100u16 expect_raises(ArgumentError) do diff --git a/spec/payload_spec.cr b/spec/payload_spec.cr new file mode 100644 index 0000000..c25d48c --- /dev/null +++ b/spec/payload_spec.cr @@ -0,0 +1,77 @@ +require "./spec_helper" + +describe MQTT::Protocol::Payload do + it ".new(Bytes) returns a BytesPayload" do + obj = MQTT::Protocol::Payload.new("foo".to_slice) + obj.should be_a(MQTT::Protocol::BytesPayload) + end + + it ".new(IO) returns a IOPayload" do + io = IO::Memory.new + io.write "foo".to_slice + obj = MQTT::Protocol::Payload.new(io, 3) + obj.should be_a(MQTT::Protocol::IOPayload) + end + + describe "#==" do + it "should return true for two BytePayload with same bytes" do + one = MQTT::Protocol::BytesPayload.new("foo".to_slice) + two = MQTT::Protocol::BytesPayload.new("foo".to_slice) + + (one == two).should be_true + end + + it "should return false for two BytePayload with different bytes" do + one = MQTT::Protocol::BytesPayload.new("foo".to_slice) + two = MQTT::Protocol::BytesPayload.new("bar".to_slice) + + (one == two).should be_false + end + + it "should return true for two IOPayload with same content" do + io_one = IO::Memory.new("foo".to_slice) + io_two = IO::Memory.new("foo".to_slice) + + io_one.rewind + io_two.rewind + + one = MQTT::Protocol::IOPayload.new(io_one, 3) + two = MQTT::Protocol::IOPayload.new(io_two, 3) + + (one == two).should be_true + end + + it "should return false for two IOPayload with different content" do + io_one = IO::Memory.new("foo".to_slice) + io_two = IO::Memory.new("bar".to_slice) + + io_one.rewind + io_two.rewind + + one = MQTT::Protocol::IOPayload.new(io_one, 3) + two = MQTT::Protocol::IOPayload.new(io_two, 3) + + (one == two).should be_false + end + + it "should return true for one BytesPayload and one IOPayload with same content" do + io_two = IO::Memory.new("foo".to_slice) + io_two.rewind + + one = MQTT::Protocol::BytesPayload.new("foo".to_slice) + two = MQTT::Protocol::IOPayload.new(io_two, 3) + + (one == two).should be_true + end + + it "should return false for one BytesPayload and one IOPayload with different content" do + io_two = IO::Memory.new("bar".to_slice) + io_two.rewind + + one = MQTT::Protocol::BytesPayload.new("foo".to_slice) + two = MQTT::Protocol::IOPayload.new(io_two, 3) + + (one == two).should be_false + end + end +end diff --git a/src/mqtt/protocol/io.cr b/src/mqtt/protocol/io.cr index 2cd204c..c225ac1 100644 --- a/src/mqtt/protocol/io.cr +++ b/src/mqtt/protocol/io.cr @@ -68,6 +68,7 @@ module MQTT @io.write bytes end + @[Deprecated("Use write_bytes instead")] def write_bytes_raw(bytes : Bytes) @io.write bytes end diff --git a/src/mqtt/protocol/packets/publish.cr b/src/mqtt/protocol/packets/publish.cr index d8bd06d..01ba88b 100644 --- a/src/mqtt/protocol/packets/publish.cr +++ b/src/mqtt/protocol/packets/publish.cr @@ -1,3 +1,4 @@ +require "../payload" require "./packets" module MQTT @@ -8,7 +9,7 @@ module MQTT getter topic, payload, qos, packet_id, remaining_length getter? dup, retain - def initialize(@topic : String, @payload : Bytes, @packet_id : UInt16?, @dup : Bool, @qos : UInt8, @retain : Bool) + def initialize(@topic : String, @payload : Payload, @packet_id : UInt16?, @dup : Bool, @qos : UInt8, @retain : Bool) raise ArgumentError.new("QoS must be 0, 1 or 2") if @qos > 2 raise ArgumentError.new("Topic cannot contain wildcard") if @topic.matches?(/[#+]/) raise ArgumentError.new("Topic must be between atleast 1 char long") if @topic.size < 1 @@ -19,6 +20,19 @@ module MQTT @remaining_length += 2 if qos.positive? # packet_id end + @[Deprecated("Use Payload instead of Bytes for @payload")] + def initialize(@topic : String, payload : Bytes, @packet_id : UInt16?, @dup : Bool, @qos : UInt8, @retain : Bool) + raise ArgumentError.new("QoS must be 0, 1 or 2") if @qos > 2 + raise ArgumentError.new("Topic cannot contain wildcard") if @topic.matches?(/[#+]/) + raise ArgumentError.new("Topic must be between atleast 1 char long") if @topic.size < 1 + raise ArgumentError.new("Topic cannot be larger than 65535 bytes") if @topic.bytesize > 65535 + raise ArgumentError.new("DUP must be 0 for QoS 0 messages") if dup? && qos.zero? + @payload = BytesPayload.new(payload) + @remaining_length = 0 + @remaining_length += (2 + topic.bytesize) + payload.bytesize + @remaining_length += 2 if qos.positive? # packet_id + end + def self.from_io(io : MQTT::Protocol::IO, flags : Flags, remaining_length : UInt32) dup = flags.bit(3) > 0 retain = flags.bit(0) > 0 @@ -32,7 +46,7 @@ module MQTT else decode_assert dup == false, "DUP must be 0 for QoS 0 messages" end - payload = io.read_bytes(remaining_length.to_u16) + payload = IOPayload.new(io, remaining_length.to_i32) self.new(topic, payload, packet_id, dup, qos, retain) rescue ex : ArgumentError raise MQTT::Protocol::Error::PacketDecode.new(ex.message) @@ -47,7 +61,7 @@ module MQTT io.write_remaining_length remaining_length io.write_string topic io.write_int packet_id.not_nil! if qos.positive? - io.write_bytes_raw(payload) + io.write_bytes payload end end end diff --git a/src/mqtt/protocol/payload.cr b/src/mqtt/protocol/payload.cr new file mode 100644 index 0000000..aa6737c --- /dev/null +++ b/src/mqtt/protocol/payload.cr @@ -0,0 +1,82 @@ +require "./io" + +module MQTT + module Protocol + abstract struct Payload + def self.new(bytes : Bytes) + BytesPayload.new(bytes) + end + + def self.new(io : ::IO, bytesize : Int32) + IOPayload.new(MQTT::Protocol::IO.new(io), bytesize) + end + + def self.new(io : MQTT::Protocol::IO, bytesize : Int32) + IOPayload.new(io, bytesize) + end + + def size + bytesize + end + + abstract def bytesize : Int32 + abstract def to_slice : Bytes + abstract def to_io(io, format : ::IO::ByteFormat = IO::ByteFormat::SystemEndian) + + def ==(other) + return false unless other.is_a?(Payload) + to_slice == other.to_slice + end + end + + struct BytesPayload < Payload + def initialize(@bytes : Bytes) + end + + def bytesize : Int32 + @bytes.bytesize + end + + def to_slice : Bytes + @bytes + end + + def to_io(io, format : ::IO::ByteFormat = IO::ByteFormat::SystemEndian) + io.write @bytes + end + end + + struct IOPayload < Payload + getter bytesize : Int32 + + @data : Bytes? = nil + + def initialize(@io : MQTT::Protocol::IO, @bytesize : Int32) + end + + def initialize(io : ::IO, @bytesize : Int32) + @io = MQTT::Protocol::IO.new(io) + end + + def to_slice : Bytes + if peeked = @io.peek.try &.[0, bytesize]? + return peeked + end + return @data || begin + data = Bytes.new(bytesize) + @io.read(data) + data + end + end + + def to_io(io, format : ::IO::ByteFormat = IO::ByteFormat::SystemEndian) + if data = @data + io.write data + else + copied = ::IO.copy(@io, io, bytesize) + raise "Failed to copy payload" if copied != bytesize + end + end + end + end +end