From ed808defe4dbba859244e1d4b3a579930ff9a240 Mon Sep 17 00:00:00 2001 From: Chinmay Date: Thu, 9 Nov 2023 15:33:44 +0530 Subject: [PATCH] Non Parser related changes removal and minor code fixes Signed-off-by: Chinmay --- src/stirling/binaries/stirling_wrapper.cc | 2 +- .../bcc_bpf/protocol_inference.h | 122 ----------- .../socket_tracer/protocols/BUILD.bazel | 1 + .../socket_tracer/protocols/mqtt/BUILD.bazel | 1 - .../socket_tracer/protocols/mqtt/parse.cc | 205 +++++++++--------- .../socket_tracer/protocols/mqtt/parse.h | 2 +- .../protocols/mqtt/parse_test.cc | 40 ++-- .../socket_tracer/protocols/mqtt/types.h | 47 +++- .../socket_tracer/protocols/stitchers.h | 1 + .../socket_tracer/protocols/types.h | 4 +- .../socket_tracer/socket_trace_connector.cc | 36 +++ .../socket_tracer/socket_trace_connector.h | 4 +- 12 files changed, 209 insertions(+), 256 deletions(-) diff --git a/src/stirling/binaries/stirling_wrapper.cc b/src/stirling/binaries/stirling_wrapper.cc index 8bb8ac8bb63..485b460e150 100644 --- a/src/stirling/binaries/stirling_wrapper.cc +++ b/src/stirling/binaries/stirling_wrapper.cc @@ -62,7 +62,7 @@ DEFINE_string(trace, "", "Dynamic trace to deploy. Either (1) the path to a file containing PxL or IR trace " "spec, or (2) : for full-function tracing."); DEFINE_string(print_record_batches, - "http_events,mysql_events,pgsql_events,redis_events,cql_events,dns_events", + "http_events,mysql_events,pgsql_events,redis_events,cql_events,dns_events,mqtt_events", "Comma-separated list of tables to print."); DEFINE_bool(init_only, false, "If true, only runs the init phase and exits. For testing."); DEFINE_int32(timeout_secs, -1, diff --git a/src/stirling/source_connectors/socket_tracer/bcc_bpf/protocol_inference.h b/src/stirling/source_connectors/socket_tracer/bcc_bpf/protocol_inference.h index fa70e134545..36aedc1d5d2 100644 --- a/src/stirling/source_connectors/socket_tracer/bcc_bpf/protocol_inference.h +++ b/src/stirling/source_connectors/socket_tracer/bcc_bpf/protocol_inference.h @@ -683,124 +683,6 @@ static __inline enum message_type_t infer_nats_message(const char* buf, size_t c return kUnknown; } -static __inline enum message_type_t infer_mqtt_message(const char* buf, size_t count) { - static const uint8_t kConnect = 0x10; - static const uint8_t kConnack = 0x20; - static const uint8_t kPublish = 0x30; - static const uint8_t kPuback = 0x40; - static const uint8_t kPubrec = 0x50; - static const uint8_t kPubrel = 0x62; - static const uint8_t kPubcomp = 0x70; - static const uint8_t kSubscribe = 0x82; - static const uint8_t kSuback = 0x90; - static const uint8_t kUnsubscribe = 0xa2; - static const uint8_t kUnsuback = 0xb0; - static const uint8_t kPingreq = 0xc0; - static const uint8_t kPingresp = 0xd0; - static const uint8_t kDisconnect = 0xe0; - - static const uint8_t kPublishQos1 = 0x32; - static const uint8_t kPublishQos2 =0x34; - static const uint8_t kDupMask = 0x08; - static const uint8_t kRetainMask = 0x01; - - static const uint8_t kMinFixedHeaderLength = 2; - - const uint8_t* ubuf = (const uint8_t*)buf; - - // Minimum Size of Fixed Header in MQTT is 2 - if (count < kMinFixedHeaderLength) { - return kUnknown; - } - - // Remaining length can be ranging from 2 to 4 bytes, decoding the variable length remaining length field - int byte_counter = 1; - int multiplier = 1; - int decoded_remaining_length = 0; - uint8_t encoded_byte; - - do { - // message size cannot be less than what the remaining length needs based on variable encoding - if (count <= (size_t)byte_counter) { - return kUnknown; - } - - encoded_byte = ubuf[byte_counter]; - decoded_remaining_length += (encoded_byte & 127) * multiplier; - // size of the remaining length cannot be above 4 bytes - if (multiplier > 128*128*128) { - return kUnknown; - } - multiplier *= 128; - byte_counter += 1; - } while ((encoded_byte & 128) != 0); - - int fixed_header_length = byte_counter; - size_t actual_remaining_length = count - (size_t)fixed_header_length; - - if (actual_remaining_length != decoded_remaining_length) { - return kUnknown; - } - - switch (ubuf[0]) - { - case kConnect: - return kRequest; - case kConnack: - return kRequest; - case kPublish: - return kRequest; - case kPublishQos1: - return kRequest; - case kPublishQos2: - return kRequest; - case kPublish | kDupMask: - return kRequest; - case kPublishQos1 | kDupMask: - return kRequest; - case kPublishQos2 | kDupMask: - return kRequest; - case kPublish | kRetainMask: - return kRequest; - case kPublishQos1 | kRetainMask: - return kRequest; - case kPublishQos2 | kRetainMask: - return kRequest; - case kPublish | kDupMask | kRetainMask: - return kRequest; - case kPublishQos1 | kDupMask | kRetainMask: - return kRequest; - case kPublishQos2 | kDupMask | kRetainMask: - return kRequest; - case kPuback: - return kRequest; - case kPubrec: - return kRequest; - case kPubrel: - return kRequest; - case kPubcomp: - return kRequest; - case kSubscribe: - return kRequest; - case kSuback: - return kRequest; - case kUnsubscribe: - return kRequest; - case kUnsuback: - return kRequest; - case kPingreq: - return kRequest; - case kPingresp: - return kRequest; - case kDisconnect: - return kRequest; - default: - return kUnknown; - } - - return kUnknown; -} - static __inline struct protocol_message_t infer_protocol(const char* buf, size_t count, struct conn_info_t* conn_info) { struct protocol_message_t inferred_message; @@ -850,12 +732,8 @@ static __inline struct protocol_message_t infer_protocol(const char* buf, size_t } else if (ENABLE_NATS_TRACING && (inferred_message.type = infer_nats_message(buf, count)) != kUnknown) { inferred_message.protocol = kProtocolNATS; - } else if (ENABLE_MQTT_TRACING && - (inferred_message.type = infer_mqtt_message(buf, count)) != kUnknown) { - inferred_message.protocol = kProtocolMQTT; } - conn_info->prev_count = count; if (count == 4) { conn_info->prev_buf[0] = buf[0]; diff --git a/src/stirling/source_connectors/socket_tracer/protocols/BUILD.bazel b/src/stirling/source_connectors/socket_tracer/protocols/BUILD.bazel index 26a4e6d26b6..505d512dac6 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/BUILD.bazel +++ b/src/stirling/source_connectors/socket_tracer/protocols/BUILD.bazel @@ -45,5 +45,6 @@ pl_cc_library( "//src/stirling/source_connectors/socket_tracer/protocols/nats:cc_library", "//src/stirling/source_connectors/socket_tracer/protocols/pgsql:cc_library", "//src/stirling/source_connectors/socket_tracer/protocols/redis:cc_library", + "//src/stirling/source_connectors/socket_tracer/protocols/mqtt:cc_library", ], ) diff --git a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/BUILD.bazel b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/BUILD.bazel index a131785c15d..366a2b8e035 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/BUILD.bazel +++ b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/BUILD.bazel @@ -35,7 +35,6 @@ pl_cc_library( ], ), deps = [ - "//src/common/json:cc_library", "//src/stirling/source_connectors/socket_tracer/protocols/common:cc_library", "//src/stirling/utils:cc_library", ], diff --git a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.cc b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.cc index 8cdf4bb6977..25a55829255 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.cc +++ b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.cc @@ -25,10 +25,9 @@ #include "src/stirling/utils/binary_decoder.h" #include "src/stirling/source_connectors/socket_tracer/protocols/mqtt/types.h" -#define PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(expr, val_or) \ +#define PX_ASSIGN_OR_RETURN_ERROR(expr, val_or) \ PX_ASSIGN_OR(expr, val_or, return ParseState::kNeedsMoreData) -#define PX_ASSIGN_OR_RETURN_INVALID(expr, val_or) \ - PX_ASSIGN_OR(expr, val_or, return ParseState::kInvalid) + namespace px { namespace stirling { namespace protocols { @@ -53,6 +52,24 @@ enum class MqttControlPacketType : uint8_t { INVALID = 0xff, }; +std::unordered_map ControlPacketTypeStrings = { + {MqttControlPacketType::CONNECT, "CONNECT"}, + {MqttControlPacketType::CONNACK, "CONNACK"}, + {MqttControlPacketType::PUBLISH, "PUBLISH"}, + {MqttControlPacketType::PUBACK, "PUBACK"}, + {MqttControlPacketType::PUBREC, "PUBREC"}, + {MqttControlPacketType::PUBREL, "PUBREL"}, + {MqttControlPacketType::PUBCOMP, "PUBCOMP"}, + {MqttControlPacketType::SUBSCRIBE, "SUBSCRIBE"}, + {MqttControlPacketType::SUBACK, "SUBACK"}, + {MqttControlPacketType::UNSUBSCRIBE, "UNSUBSCRIBE"}, + {MqttControlPacketType::UNSUBACK, "UNSUBACK"}, + {MqttControlPacketType::PINGREQ, "PINGREQ"}, + {MqttControlPacketType::PINGRESP, "PINGRESP"}, + {MqttControlPacketType::DISCONNECT, "DISCONNECT"}, + {MqttControlPacketType::INVALID, "INVALID"}, +}; + enum class PropertyCode: uint8_t { PayloadFormatIndicator = 0x01, MessageExpiryInterval = 0x02, @@ -140,12 +157,12 @@ ParseState ParseProperties(Message* result, BinaryDecoder* decoder, size_t& prop uint8_t property_code; while (properties_length > 0) { // Extracting the property code - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(property_code, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(property_code, decoder->ExtractBEInt()); properties_length -= 1; switch (property_code) { case static_cast(PropertyCode::PayloadFormatIndicator): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint8_t payload_format_indicator, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint8_t payload_format_indicator, decoder->ExtractBEInt()); if (payload_format_indicator == 0x00) { result->properties["payload_format"] = "unspecified"; } else if (payload_format_indicator == 0x01) { @@ -157,31 +174,31 @@ ParseState ParseProperties(Message* result, BinaryDecoder* decoder, size_t& prop break; } case static_cast(PropertyCode::MessageExpiryInterval): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint32_t message_expiry_interval, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint32_t message_expiry_interval, decoder->ExtractBEInt()); result->properties["message_expiry_interval"] = std::to_string(message_expiry_interval); properties_length -= 4; break; } case static_cast(PropertyCode::ContentType): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t property_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t property_length, decoder->ExtractBEInt()); properties_length -= 2; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view content_type, decoder->ExtractString(property_length)); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view content_type, decoder->ExtractString(property_length)); result->properties["content_type"] = std::string(content_type); properties_length -= property_length; break; } case static_cast(PropertyCode::ResponseTopic): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t property_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t property_length, decoder->ExtractBEInt()); properties_length -= 2; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view response_topic, decoder->ExtractString(properties_length)); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view response_topic, decoder->ExtractString(properties_length)); result->properties["response_topic"] = std::string(response_topic); properties_length -= property_length; break; } case static_cast(PropertyCode::CorrelationData): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t property_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t property_length, decoder->ExtractBEInt()); properties_length -= 2; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view correlation_data, decoder->ExtractString(property_length)); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view correlation_data, decoder->ExtractString(property_length)); result->properties["correlation_data"] = std::string(correlation_data); properties_length -= property_length; break; @@ -190,7 +207,7 @@ ParseState ParseProperties(Message* result, BinaryDecoder* decoder, size_t& prop unsigned long subscription_id; size_t num_bytes; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(subscription_id, decoder->ExtractUVarInt()); + PX_ASSIGN_OR_RETURN_ERROR(subscription_id, decoder->ExtractUVarInt()); StatusOr num_bytes_status = VariableEncodingNumBytes(subscription_id); if (!num_bytes_status.ok()) { return ParseState::kInvalid; @@ -202,121 +219,121 @@ ParseState ParseProperties(Message* result, BinaryDecoder* decoder, size_t& prop break; } case static_cast(PropertyCode::SessionExpiryInterval): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint32_t session_expiry_interval, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint32_t session_expiry_interval, decoder->ExtractBEInt()); result->properties["session_expiry_interval"] = std::to_string(session_expiry_interval); properties_length -= 4; break; } case static_cast(PropertyCode::AssignedClientIdentifier): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t property_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t property_length, decoder->ExtractBEInt()); properties_length -= 2; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view assigned_client_identifier, decoder->ExtractString(property_length)); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view assigned_client_identifier, decoder->ExtractString(property_length)); result->properties["assigned_client_identifier"] = std::string(assigned_client_identifier); properties_length -= property_length; break; } case static_cast(PropertyCode::ServerKeepAlive): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t server_keep_alive, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t server_keep_alive, decoder->ExtractBEInt()); result->properties["server_keep_alive"] = std::to_string(server_keep_alive); properties_length -= 2; break; } case static_cast(PropertyCode::AuthenticationMethod): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t property_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t property_length, decoder->ExtractBEInt()); properties_length -= 2; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view auth_method, decoder->ExtractString(property_length)); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view auth_method, decoder->ExtractString(property_length)); result->properties["auth_method"] = std::string(auth_method); properties_length -= property_length; break; } case static_cast(PropertyCode::AuthenticationData): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t property_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t property_length, decoder->ExtractBEInt()); properties_length -= 2; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view auth_data, decoder->ExtractString(property_length)); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view auth_data, decoder->ExtractString(property_length)); result->properties["auth_data"] = std::string(auth_data); properties_length -= property_length; break; } case static_cast(PropertyCode::RequestProblemInformation): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint8_t request_problem_information, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint8_t request_problem_information, decoder->ExtractBEInt()); result->properties["request_problem_information"] = std::to_string(request_problem_information); properties_length -= 1; break; } case static_cast(PropertyCode::WillDelayInterval): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint32_t will_delay_interval, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint32_t will_delay_interval, decoder->ExtractBEInt()); result->properties["will_delay_interval"] = std::to_string(will_delay_interval); properties_length -= 4; break; } case static_cast(PropertyCode::RequestResponseInformation): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint8_t request_response_information, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint8_t request_response_information, decoder->ExtractBEInt()); result->properties["request_response_information"] = std::to_string(request_response_information); properties_length -= 1; break; } case static_cast(PropertyCode::ResponseInformation): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t property_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t property_length, decoder->ExtractBEInt()); properties_length -= 2; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view response_information, decoder->ExtractString(properties_length)); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view response_information, decoder->ExtractString(properties_length)); result->properties["response_information"] = std::string(response_information); properties_length -= property_length; break; } case static_cast(PropertyCode::ServerReference): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t property_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t property_length, decoder->ExtractBEInt()); properties_length -= 2; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view server_reference, decoder->ExtractString(property_length)); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view server_reference, decoder->ExtractString(property_length)); result->properties["server_reference"] = std::string(server_reference); properties_length -= property_length; break; } case static_cast(PropertyCode::ReasonString): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t property_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t property_length, decoder->ExtractBEInt()); properties_length -= 2; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view reason_string, decoder->ExtractString(property_length)); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view reason_string, decoder->ExtractString(property_length)); result->properties["reason_string"] = std::string(reason_string); properties_length -= property_length; break; } case static_cast(PropertyCode::ReceiveMaximum): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t receive_maximum, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t receive_maximum, decoder->ExtractBEInt()); result->properties["receive_maximum"] = std::to_string(receive_maximum); properties_length -= 2; break; } case static_cast(PropertyCode::TopicAliasMaximum): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t topic_alias_maximum, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t topic_alias_maximum, decoder->ExtractBEInt()); result->properties["topic_alias_maximum"] = std::to_string(topic_alias_maximum); properties_length -= 2; break; } case static_cast(PropertyCode::TopicAlias): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t topic_alias, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t topic_alias, decoder->ExtractBEInt()); result->properties["topic_alias"] = std::to_string(topic_alias); properties_length -= 2; break; } case static_cast(PropertyCode::MaximumQos): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t topic_alias, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t topic_alias, decoder->ExtractBEInt()); result->properties["topic_alias"] = std::to_string(topic_alias); properties_length -= 2; break; } case static_cast(PropertyCode::RetainAvailable): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint8_t retain_available, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint8_t retain_available, decoder->ExtractBEInt()); result->properties["retain_available"] = std::to_string(retain_available); properties_length -= 1; break; } case static_cast(PropertyCode::UserProperty): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t key_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t key_length, decoder->ExtractBEInt()); properties_length -= 2; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view key, decoder->ExtractString(key_length)); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view key, decoder->ExtractString(key_length)); properties_length -= key_length; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t value_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t value_length, decoder->ExtractBEInt()); properties_length -= 2; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view value, decoder->ExtractString(value_length)); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view value, decoder->ExtractString(value_length)); properties_length -= value_length; // For multiple user properties present, append to string if user property already present if (result->properties.find("user-properties") == result->properties.end()) { @@ -327,25 +344,25 @@ ParseState ParseProperties(Message* result, BinaryDecoder* decoder, size_t& prop break; } case static_cast(PropertyCode::MaximumPacketSize): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint32_t maximum_packet_size, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint32_t maximum_packet_size, decoder->ExtractBEInt()); result->properties["maximum_packet_size"] = std::to_string(maximum_packet_size); properties_length -= 4; break; } case static_cast(PropertyCode::WildcardSubscriptionAvailable): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint8_t wildcard_subscription_available, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint8_t wildcard_subscription_available, decoder->ExtractBEInt()); result->properties["retain_available"] = (wildcard_subscription_available == 1)?"true":"false"; properties_length -= 1; break; } case static_cast(PropertyCode::SubscriptionIdentifiersAvailable): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint8_t subscription_id_available, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint8_t subscription_id_available, decoder->ExtractBEInt()); result->properties["subscription_id_available"] = (subscription_id_available == 1)?"true":"false"; properties_length -= 1; break; } case static_cast(PropertyCode::SharedSubscriptionAvailable): { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint8_t shared_subscription_available, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint8_t shared_subscription_available, decoder->ExtractBEInt()); result->properties["subscription_id_available"] = (shared_subscription_available == 1)?"true":"false"; properties_length -= 1; break; @@ -360,13 +377,13 @@ ParseState ParseProperties(Message* result, BinaryDecoder* decoder, size_t& prop ParseState ParseVariableHeader(Message* result, BinaryDecoder* decoder, MqttControlPacketType& control_packet_type) { switch (control_packet_type) { case MqttControlPacketType::CONNECT: { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t protocol_name_length, decoder->ExtractBEInt()); - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view protocol_name, decoder->ExtractString(protocol_name_length)); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t protocol_name_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view protocol_name, decoder->ExtractString(protocol_name_length)); CTX_DCHECK(protocol_name == "MQTT"); - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint8_t protocol_version, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint8_t protocol_version, decoder->ExtractBEInt()); CTX_DCHECK(protocol_version == 5); - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint8_t connect_flags, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint8_t connect_flags, decoder->ExtractBEInt()); result->header_fields["username_flag"] = connect_flags >> 7; result->header_fields["password_flag"] = (connect_flags >> 6) & 0x1; result->header_fields["will_retain"] = (connect_flags >> 5) & 0x1; @@ -374,11 +391,11 @@ ParseState ParseVariableHeader(Message* result, BinaryDecoder* decoder, MqttCont result->header_fields["will_flag"] = (connect_flags >> 2) & 0x1; result->header_fields["clean_start"] = (connect_flags >> 1) & 0x1; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(result->header_fields["keep_alive"], decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(result->header_fields["keep_alive"], decoder->ExtractBEInt()); size_t properties_length; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(properties_length, decoder->ExtractUVarInt()); + PX_ASSIGN_OR_RETURN_ERROR(properties_length, decoder->ExtractUVarInt()); if (!VariableEncodingNumBytes(properties_length).ok()) { return ParseState::kInvalid; } @@ -386,14 +403,14 @@ ParseState ParseVariableHeader(Message* result, BinaryDecoder* decoder, MqttCont return ParseProperties(result, decoder, properties_length); } case MqttControlPacketType::CONNACK: { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint8_t connack_flags, decoder->ExtractBEInt()); - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(result->header_fields["reason_code"], decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint8_t connack_flags, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(result->header_fields["reason_code"], decoder->ExtractBEInt()); result->header_fields["session_present"] = connack_flags; size_t properties_length; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(properties_length, decoder->ExtractUVarInt()); + PX_ASSIGN_OR_RETURN_ERROR(properties_length, decoder->ExtractUVarInt()); if (!VariableEncodingNumBytes(properties_length).ok()) { return ParseState::kInvalid; } @@ -401,8 +418,8 @@ ParseState ParseVariableHeader(Message* result, BinaryDecoder* decoder, MqttCont return ParseProperties(result, decoder, properties_length); } case MqttControlPacketType::PUBLISH: { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t topic_length, decoder->ExtractBEInt()); - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view topic_name, decoder->ExtractString(topic_length)); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t topic_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view topic_name, decoder->ExtractString(topic_length)); result->payload["topic_name"] = std::string(topic_name); // Storing variable header length for use in payload length calculation @@ -413,12 +430,12 @@ ParseState ParseVariableHeader(Message* result, BinaryDecoder* decoder, MqttCont return ParseState::kInvalid; } if (result->header_fields["qos"] != 0) { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(result->header_fields["packet_identifier"], decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(result->header_fields["packet_identifier"], decoder->ExtractBEInt()); result->header_fields["variable_header_length"] += 2; } size_t properties_length, num_bytes; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(properties_length, decoder->ExtractUVarInt()); + PX_ASSIGN_OR_RETURN_ERROR(properties_length, decoder->ExtractUVarInt()); StatusOr num_bytes_status = VariableEncodingNumBytes(properties_length); if (!num_bytes_status.ok()) { return ParseState::kInvalid; @@ -433,17 +450,17 @@ ParseState ParseVariableHeader(Message* result, BinaryDecoder* decoder, MqttCont case MqttControlPacketType::PUBREC: case MqttControlPacketType::PUBREL: case MqttControlPacketType::PUBCOMP: { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(result->header_fields["packet_identifier"], decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(result->header_fields["packet_identifier"], decoder->ExtractBEInt()); if(result->header_fields.find("remaining_length") == result->header_fields.end()) { return ParseState::kInvalid; } if (result->header_fields["remaining_length"] >= 3) { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(result->header_fields["reason_code"], decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(result->header_fields["reason_code"], decoder->ExtractBEInt()); } if (result->header_fields["remaining_length"] >= 4) { size_t properties_length; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(properties_length, decoder->ExtractUVarInt()); + PX_ASSIGN_OR_RETURN_ERROR(properties_length, decoder->ExtractUVarInt()); if (!VariableEncodingNumBytes(properties_length).ok()) { return ParseState::kInvalid; } @@ -456,12 +473,12 @@ ParseState ParseVariableHeader(Message* result, BinaryDecoder* decoder, MqttCont case MqttControlPacketType::SUBACK: case MqttControlPacketType::UNSUBSCRIBE: case MqttControlPacketType::UNSUBACK: { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(result->header_fields["packet_identifier"], decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(result->header_fields["packet_identifier"], decoder->ExtractBEInt()); // Storing variable header length for use in payload length calculation result->header_fields["variable_header_length"] = 2; size_t properties_length, num_bytes; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(properties_length, decoder->ExtractUVarInt()); + PX_ASSIGN_OR_RETURN_ERROR(properties_length, decoder->ExtractUVarInt()); StatusOr num_bytes_status = VariableEncodingNumBytes(properties_length); if (!num_bytes_status.ok()) { return ParseState::kInvalid; @@ -472,12 +489,12 @@ ParseState ParseVariableHeader(Message* result, BinaryDecoder* decoder, MqttCont return ParseProperties(result, decoder, properties_length); } case MqttControlPacketType::DISCONNECT: { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(result->header_fields["reason_code"], decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(result->header_fields["reason_code"], decoder->ExtractBEInt()); if (result->header_fields["remaining_length"] > 1) { size_t properties_length; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(properties_length, decoder->ExtractUVarInt()); + PX_ASSIGN_OR_RETURN_ERROR(properties_length, decoder->ExtractUVarInt()); if (!VariableEncodingNumBytes(properties_length).ok()) { return ParseState::kInvalid; } @@ -494,14 +511,14 @@ ParseState ParseVariableHeader(Message* result, BinaryDecoder* decoder, MqttCont ParseState ParsePayload(Message* result, BinaryDecoder* decoder, MqttControlPacketType& control_packet_type) { switch (control_packet_type) { case MqttControlPacketType::CONNECT: { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint16_t client_id_length, decoder->ExtractBEInt()); - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view client_id, decoder->ExtractString(client_id_length)); + PX_ASSIGN_OR_RETURN_ERROR(uint16_t client_id_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view client_id, decoder->ExtractString(client_id_length)); result->payload["client_id"] = std::string(client_id); if (result->header_fields["will_flag"]) { size_t will_properties_length, will_topic_length, will_payload_length; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(will_properties_length, decoder->ExtractUVarInt()); + PX_ASSIGN_OR_RETURN_ERROR(will_properties_length, decoder->ExtractUVarInt()); if (!VariableEncodingNumBytes(will_properties_length).ok()) { return ParseState::kInvalid; } @@ -510,24 +527,24 @@ ParseState ParsePayload(Message* result, BinaryDecoder* decoder, MqttControlPack return ParseState::kInvalid; } - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(will_topic_length, decoder->ExtractBEInt()); - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view will_topic, decoder->ExtractString(will_topic_length)); + PX_ASSIGN_OR_RETURN_ERROR(will_topic_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view will_topic, decoder->ExtractString(will_topic_length)); result->payload["will_topic"] = std::string(will_topic); - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(will_payload_length, decoder->ExtractBEInt()); - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view will_payload, decoder->ExtractString(will_payload_length)); + PX_ASSIGN_OR_RETURN_ERROR(will_payload_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view will_payload, decoder->ExtractString(will_payload_length)); result->payload["will_payload"] = std::string(will_payload); } if (result->header_fields["username_flag"]) { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(size_t username_length, decoder->ExtractBEInt()); - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view username, decoder->ExtractString(username_length)); + PX_ASSIGN_OR_RETURN_ERROR(size_t username_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view username, decoder->ExtractString(username_length)); result->payload["username"] = std::string(username); } if (result->header_fields["password_flag"]) { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(size_t password_length, decoder->ExtractBEInt()); - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::ignore, decoder->ExtractString(password_length)); + PX_ASSIGN_OR_RETURN_ERROR(size_t password_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(std::ignore, decoder->ExtractString(password_length)); } return ParseState::kSuccess; @@ -540,7 +557,7 @@ ParseState ParsePayload(Message* result, BinaryDecoder* decoder, MqttControlPack return ParseState::kInvalid; } size_t payload_length = result->header_fields["remaining_length"] - result->header_fields["variable_header_length"]; - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view payload, decoder->ExtractString(payload_length)); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view payload, decoder->ExtractString(payload_length)); result->payload["publish_message"] = std::string(payload); return ParseState::kSuccess; } @@ -563,14 +580,14 @@ ParseState ParsePayload(Message* result, BinaryDecoder* decoder, MqttControlPack result->payload["subscription_options"] = ""; payload_length = result->header_fields["remaining_length"] - result->header_fields["variable_header_length"]; while (payload_length > 0) { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(topic_filter_length, decoder->ExtractBEInt()); - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view topic_filter, decoder->ExtractString(topic_filter_length)); + PX_ASSIGN_OR_RETURN_ERROR(topic_filter_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view topic_filter, decoder->ExtractString(topic_filter_length)); if (result->payload["topic_filter"].empty()) { result->payload["topic_filter"] += std::string(topic_filter); } else { result->payload["topic_filter"] += ", " + std::string(topic_filter); } - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(subscription_options, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(subscription_options, decoder->ExtractBEInt()); result->payload["subscription_options"] += "{maximum_qos : " + std::to_string(subscription_options & 0x3) + ", no_local : " + std::to_string((subscription_options >> 2) & 0x1) + ", retain_as_published : " + std::to_string((subscription_options >> 3) & 0x1) + @@ -591,8 +608,8 @@ ParseState ParsePayload(Message* result, BinaryDecoder* decoder, MqttControlPack result->payload["topic_filter"] = ""; payload_length = result->header_fields["remaining_length"] - result->header_fields["variable_header_length"]; while (payload_length > 0) { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(topic_filter_length, decoder->ExtractBEInt()); - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(std::string_view topic_filter, decoder->ExtractString(topic_filter_length)); + PX_ASSIGN_OR_RETURN_ERROR(topic_filter_length, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(std::string_view topic_filter, decoder->ExtractString(topic_filter_length)); if (result->payload["topic_filter"].empty()) { result->payload["topic_filter"] += std::string(topic_filter); } else { @@ -615,7 +632,7 @@ ParseState ParsePayload(Message* result, BinaryDecoder* decoder, MqttControlPack result->payload["reason_code"] = ""; payload_length = result->header_fields["remaining_length"] - result->header_fields["variable_header_length"]; while (payload_length > 0) { - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(reason_code, decoder->ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(reason_code, decoder->ExtractBEInt()); if (result->payload["reason_code"].empty()) { result->payload["reason_code"] += std::to_string(reason_code); } else { @@ -645,12 +662,12 @@ ParseState ParseFrame(message_type_t type, std::string_view* buf, // Parsing the fixed header // Control Packet Type extracted from first four bits of the first byte - PX_ASSIGN_OR_RETURN_NEEDS_MORE_DATA(uint8_t control_packet_code_flags, decoder.ExtractBEInt()); + PX_ASSIGN_OR_RETURN_ERROR(uint8_t control_packet_code_flags, decoder.ExtractBEInt()); uint8_t control_packet_code = control_packet_code_flags >> 4; uint8_t control_packet_flags = control_packet_code_flags & 0x0F; MqttControlPacketType control_packet_type = GetControlPacketType(control_packet_code); - result->control_packet_type = control_packet_code; + result->control_packet_type = ControlPacketTypeStrings[control_packet_type]; // Saving the flags if control packet type is PUBLISH if (control_packet_type == MqttControlPacketType::PUBLISH) { @@ -660,26 +677,18 @@ ParseState ParseFrame(message_type_t type, std::string_view* buf, } // Decoding the variable encoding of remaining length field - size_t remaining_length; - if (control_packet_type == MqttControlPacketType::PINGREQ || control_packet_type == MqttControlPacketType::PINGRESP) { - PX_ASSIGN_OR_RETURN_INVALID(remaining_length, decoder.ExtractUVarInt()); - if (remaining_length > 127) { - return ParseState::kInvalid; - } - } else { - if (decoder.BufSize() < 5) { - return ParseState::kNeedsMoreData; - } - PX_ASSIGN_OR_RETURN_INVALID(remaining_length, decoder.ExtractUVarInt()); - if (!VariableEncodingNumBytes(remaining_length).ok()) { - return ParseState::kInvalid; - } + PX_ASSIGN_OR_RETURN_ERROR(size_t remaining_length, decoder.ExtractUVarInt()); + if (!VariableEncodingNumBytes(remaining_length).ok()) { + return ParseState::kInvalid; } - if (decoder.BufSize() < remaining_length) { return ParseState::kNeedsMoreData; } + + if (remaining_length < 0) { + return ParseState::kInvalid; + } result->header_fields["remaining_length"] = remaining_length; ParseState parse_variable_header_state = ParseVariableHeader(result, &decoder, control_packet_type); diff --git a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.h b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.h index 7cc13c60e10..1c5bd68c41f 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.h +++ b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse.h @@ -30,7 +30,7 @@ namespace stirling { namespace protocols { /** - * Parses a single MQTT message from the input string. + * Parses a single HTTP message from the input string. */ template<> diff --git a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse_test.cc b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse_test.cc index a73db7c9ae1..158256f54e4 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse_test.cc +++ b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/parse_test.cc @@ -693,15 +693,7 @@ TEST_F(MQTTParserTest, Headers) { // message 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64}; - uint8_t kPubackFrame[] = { - // header flags - 0x40, - // message length - 0x03, - // message identifier - 0x00, 0x01, - // reason code - 0x10}; + uint8_t kPubackFrame[] = {0x40, 0x03, 0x00, 0x01, 0x10}; uint8_t kPubrecFrame[] = { // header flags 0x50, @@ -796,7 +788,7 @@ TEST_F(MQTTParserTest, Headers) { frame_view = CreateStringView(CharArrayStringView(kConnectFrame)); result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); ASSERT_EQ(result_state, ParseState::kSuccess); - EXPECT_EQ(frame.control_packet_type, 1); + EXPECT_EQ(frame.control_packet_type, "CONNECT"); EXPECT_EQ(frame.header_fields["remaining_length"], (size_t) 16); EXPECT_EQ(frame.header_fields["username_flag"], 0); EXPECT_EQ(frame.header_fields["password_flag"], 0); @@ -810,7 +802,7 @@ TEST_F(MQTTParserTest, Headers) { frame_view = CreateStringView(CharArrayStringView(kConnackFrame)); result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame); ASSERT_EQ(result_state, ParseState::kSuccess); - EXPECT_EQ(frame.control_packet_type, 2); + EXPECT_EQ(frame.control_packet_type, "CONNACK"); EXPECT_EQ(frame.header_fields["remaining_length"], (size_t) 53); EXPECT_EQ(frame.header_fields["session_present"], 0); EXPECT_EQ(frame.header_fields["reason_code"], 0); @@ -819,7 +811,7 @@ TEST_F(MQTTParserTest, Headers) { frame_view = CreateStringView(CharArrayStringView(kPublishFrame)); result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); ASSERT_EQ(result_state, ParseState::kSuccess); - EXPECT_EQ(frame.control_packet_type, 3); + EXPECT_EQ(frame.control_packet_type, "PUBLISH"); EXPECT_EQ(frame.header_fields["remaining_length"], (size_t) 26); EXPECT_EQ(frame.dup, false); EXPECT_EQ(frame.retain, false); @@ -829,8 +821,8 @@ TEST_F(MQTTParserTest, Headers) { frame_view = CreateStringView(CharArrayStringView(kPubackFrame)); result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame); - ASSERT_EQ(result_state, ParseState::kNeedsMoreData); - EXPECT_EQ(frame.control_packet_type, 4); + ASSERT_EQ(result_state, ParseState::kSuccess); + EXPECT_EQ(frame.control_packet_type, "PUBACK"); EXPECT_EQ(frame.header_fields["remaining_length"], (size_t) 3); EXPECT_EQ(frame.header_fields["packet_identifier"], 1); frame = Message(); @@ -838,7 +830,7 @@ TEST_F(MQTTParserTest, Headers) { frame_view = CreateStringView(CharArrayStringView(kPubrecFrame)); result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame); ASSERT_EQ(result_state, ParseState::kSuccess); - EXPECT_EQ(frame.control_packet_type, 5); + EXPECT_EQ(frame.control_packet_type, "PUBREC"); EXPECT_EQ(frame.header_fields["remaining_length"], (size_t) 2); EXPECT_EQ(frame.header_fields["packet_identifier"], 1); frame = Message(); @@ -846,7 +838,7 @@ TEST_F(MQTTParserTest, Headers) { frame_view = CreateStringView(CharArrayStringView(kPubrelFrame)); result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); ASSERT_EQ(result_state, ParseState::kSuccess); - EXPECT_EQ(frame.control_packet_type, 6); + EXPECT_EQ(frame.control_packet_type, "PUBREL"); EXPECT_EQ(frame.header_fields["remaining_length"], (size_t) 2); EXPECT_EQ(frame.header_fields["packet_identifier"], 1); frame = Message(); @@ -854,7 +846,7 @@ TEST_F(MQTTParserTest, Headers) { frame_view = CreateStringView(CharArrayStringView(kPubcompFrame)); result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame); ASSERT_EQ(result_state, ParseState::kSuccess); - EXPECT_EQ(frame.control_packet_type, 7); + EXPECT_EQ(frame.control_packet_type, "PUBCOMP"); EXPECT_EQ(frame.header_fields["remaining_length"], (size_t) 2); EXPECT_EQ(frame.header_fields["packet_identifier"], 1); frame = Message(); @@ -862,7 +854,7 @@ TEST_F(MQTTParserTest, Headers) { frame_view = CreateStringView(CharArrayStringView(kSubscribeFrame)); result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); ASSERT_EQ(result_state, ParseState::kSuccess); - EXPECT_EQ(frame.control_packet_type, 8); + EXPECT_EQ(frame.control_packet_type, "SUBSCRIBE"); EXPECT_EQ(frame.header_fields["remaining_length"], (size_t) 16); EXPECT_EQ(frame.header_fields["packet_identifier"], 1); frame = Message(); @@ -870,7 +862,7 @@ TEST_F(MQTTParserTest, Headers) { frame_view = CreateStringView(CharArrayStringView(kSubackFrame)); result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame); ASSERT_EQ(result_state, ParseState::kSuccess); - EXPECT_EQ(frame.control_packet_type, 9); + EXPECT_EQ(frame.control_packet_type, "SUBACK"); EXPECT_EQ(frame.header_fields["remaining_length"], (size_t) 4); EXPECT_EQ(frame.payload["reason_code"], "0"); EXPECT_EQ(frame.header_fields["packet_identifier"], 1); @@ -879,7 +871,7 @@ TEST_F(MQTTParserTest, Headers) { frame_view = CreateStringView(CharArrayStringView(kUnsubscribeFrame)); result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); ASSERT_EQ(result_state, ParseState::kSuccess); - EXPECT_EQ(frame.control_packet_type, 10); + EXPECT_EQ(frame.control_packet_type, "UNSUBSCRIBE"); EXPECT_EQ(frame.header_fields["remaining_length"], (size_t) 15); EXPECT_EQ(frame.header_fields["packet_identifier"], 2); frame = Message(); @@ -887,7 +879,7 @@ TEST_F(MQTTParserTest, Headers) { frame_view = CreateStringView(CharArrayStringView(kUnsubackFrame)); result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame); ASSERT_EQ(result_state, ParseState::kSuccess); - EXPECT_EQ(frame.control_packet_type, 11); + EXPECT_EQ(frame.control_packet_type, "UNSUBACK"); EXPECT_EQ(frame.header_fields["remaining_length"], (size_t) 4); EXPECT_EQ(frame.header_fields["packet_identifier"], 2); frame = Message(); @@ -895,21 +887,21 @@ TEST_F(MQTTParserTest, Headers) { frame_view = CreateStringView(CharArrayStringView(kPingreqFrame)); result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); ASSERT_EQ(result_state, ParseState::kSuccess); - EXPECT_EQ(frame.control_packet_type, 12); + EXPECT_EQ(frame.control_packet_type, "PINGREQ"); EXPECT_EQ(frame.header_fields["remaining_length"], (size_t) 0); frame = Message(); frame_view = CreateStringView(CharArrayStringView(kPingrespFrame)); result_state = ParseFrame(message_type_t::kResponse, &frame_view, &frame); ASSERT_EQ(result_state, ParseState::kSuccess); - EXPECT_EQ(frame.control_packet_type, 13); + EXPECT_EQ(frame.control_packet_type, "PINGRESP"); EXPECT_EQ(frame.header_fields["remaining_length"], (size_t) 0); frame = Message(); frame_view = CreateStringView(CharArrayStringView(kDisconnectFrame)); result_state = ParseFrame(message_type_t::kRequest, &frame_view, &frame); ASSERT_EQ(result_state, ParseState::kSuccess); - EXPECT_EQ(frame.control_packet_type, 14); + EXPECT_EQ(frame.control_packet_type, "DISCONNECT"); EXPECT_EQ(frame.header_fields["remaining_length"], (size_t) 1); EXPECT_EQ(frame.header_fields["reason_code"], 4); frame = Message(); diff --git a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/types.h b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/types.h index 698209cc21a..97c12adff4d 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/mqtt/types.h +++ b/src/stirling/source_connectors/socket_tracer/protocols/mqtt/types.h @@ -19,7 +19,6 @@ #pragma once #include "src/common/base/utils.h" -#include "src/common/json/json.h" #include "src/stirling/source_connectors/socket_tracer/protocols/common/event_parser.h" // For FrameBase. namespace px { @@ -27,15 +26,13 @@ namespace stirling { namespace protocols { namespace mqtt { -using ::px::utils::ToJSONString; - // The protocol specification : https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.pdf // This supports MQTTv5 struct Message: public FrameBase { message_type_t type = message_type_t::kUnknown; - uint8_t control_packet_type = 0xff; + std::string control_packet_type = "UNKNOWN"; bool dup; bool retain; @@ -43,17 +40,53 @@ struct Message: public FrameBase { std::map header_fields; std::map properties, payload; + template + static std::string MapToString(const std::map& inputMap) { + std::string result = "{"; + for (const auto& entry : inputMap) { + result += entry.first + ": "; + if constexpr (std::is_same_v) { + result += std::to_string(entry.second); + } else if constexpr (std::is_same_v) { + result += entry.second; + } + result += ", "; + } + if (!inputMap.empty()) { + result = result.substr(0, result.size() - 2); // Remove the trailing ", " + } + result += "}"; + return result; + } + size_t ByteSize() const override { return sizeof(Message) + payload.size(); } std::string ToString() const override { + std::string header_fields_str = "{"; + for (const auto& entry : properties) { + header_fields_str += entry.first + ": " + std::string(entry.second) + ", "; + } + header_fields_str += "}"; + + std::string properties_str = "{"; + for (const auto& entry : properties) { + properties_str += entry.first + ": " + entry.second + ", "; + } + properties_str += "}"; + + std::string payload_str = "{"; + for (const auto& entry : properties) { + payload_str += entry.first + ": " + entry.second + ", "; + } + payload_str += "}"; + return absl::Substitute( "Message: {type: $0, control_packet_type: $1, dup: $2, retain: $3, header_fields: $4, " "payload: $5, properties: $6}", magic_enum::enum_name(type), control_packet_type, dup, retain, - ToJSONString(header_fields), ToJSONString(payload), - ToJSONString(properties)); + header_fields_str, payload_str, properties_str); } }; @@ -62,7 +95,7 @@ struct Message: public FrameBase { //----------------------------------------------------------------------------- /** - * Record is the primary output of the MQTT stitcher. + * Record is the primary output of the http stitcher. */ struct Record{ Message req; diff --git a/src/stirling/source_connectors/socket_tracer/protocols/stitchers.h b/src/stirling/source_connectors/socket_tracer/protocols/stitchers.h index fba7325b6f3..c899c372b19 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/stitchers.h +++ b/src/stirling/source_connectors/socket_tracer/protocols/stitchers.h @@ -30,3 +30,4 @@ #include "src/stirling/source_connectors/socket_tracer/protocols/nats/stitcher.h" // IWYU pragma: export #include "src/stirling/source_connectors/socket_tracer/protocols/pgsql/stitcher.h" // IWYU pragma: export #include "src/stirling/source_connectors/socket_tracer/protocols/redis/stitcher.h" // IWYU pragma: export +#include "src/stirling/source_connectors/socket_tracer/protocols/mqtt/stitcher.h" // IWYU pragma: export diff --git a/src/stirling/source_connectors/socket_tracer/protocols/types.h b/src/stirling/source_connectors/socket_tracer/protocols/types.h index a0e891bbfd5..f51fbfa76ef 100644 --- a/src/stirling/source_connectors/socket_tracer/protocols/types.h +++ b/src/stirling/source_connectors/socket_tracer/protocols/types.h @@ -32,6 +32,7 @@ #include "src/stirling/source_connectors/socket_tracer/protocols/nats/types.h" #include "src/stirling/source_connectors/socket_tracer/protocols/pgsql/types.h" #include "src/stirling/source_connectors/socket_tracer/protocols/redis/types.h" +#include "src/stirling/source_connectors/socket_tracer/protocols/mqtt/types.h" namespace px { namespace stirling { @@ -49,7 +50,8 @@ using FrameDequeVariant = std::variant, std::deque, std::deque, - std::deque; + std::deque, + std::deque>; // clang-format off } // namespace protocols diff --git a/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc b/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc index 489922b884d..40809e0830f 100644 --- a/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc +++ b/src/stirling/source_connectors/socket_tracer/socket_trace_connector.cc @@ -112,6 +112,9 @@ DEFINE_int32(stirling_enable_mux_tracing, DEFINE_int32(stirling_enable_amqp_tracing, gflags::Int32FromEnv("PX_STIRLING_ENABLE_AMQP_TRACING", px::stirling::TraceMode::On), "If true, stirling will trace and process AMQP messages."); +DEFINE_int32(stirling_enable_mqtt_tracing, + gflags::Int32FromEnv("PX_STIRLING_ENABLE_MQTT_TRACING", px::stirling::TraceMode::On), + "If true, stirling will trace and process MQTT messages."); DEFINE_bool(stirling_disable_golang_tls_tracing, gflags::BoolFromEnv("PX_STIRLING_DISABLE_GOLANG_TLS_TRACING", false), @@ -271,6 +274,10 @@ void SocketTraceConnector::InitProtocolTransferSpecs() { kAMQPTableNum, {kRoleClient, kRoleServer}, TRANSFER_STREAM_PROTOCOL(amqp)}}, + {kProtocolMQTT, TransferSpec{FLAGS_stirling_enable_mqtt_tracing, + kMQTTTableNum, + {kRoleClient, kRoleServer}, + TRANSFER_STREAM_PROTOCOL(mqtt)}}, {kProtocolUnknown, TransferSpec{/* trace_mode */ px::stirling::TraceMode::Off, /* table_num */ static_cast(-1), /* trace_roles */ {}, @@ -442,6 +449,7 @@ Status SocketTraceConnector::InitBPF() { absl::StrCat("-DENABLE_REDIS_TRACING=", protocol_transfer_specs_[kProtocolRedis].enabled), absl::StrCat("-DENABLE_NATS_TRACING=", protocol_transfer_specs_[kProtocolNATS].enabled), absl::StrCat("-DENABLE_AMQP_TRACING=", protocol_transfer_specs_[kProtocolAMQP].enabled), + absl::StrCat("-DENABLE_MQTT_TRACING=", protocol_transfer_specs_[kProtocolMQTT].enabled), absl::StrCat("-DENABLE_MONGO_TRACING=", "true"), }; PX_RETURN_IF_ERROR(bcc_->InitBPFProgram(socket_trace_bcc_script, defines)); @@ -1566,6 +1574,34 @@ void SocketTraceConnector::AppendMessage(ConnectorContext* ctx, const ConnTracke #endif } +template <> +void SocketTraceConnector::AppendMessage(ConnectorContext* ctx, const ConnTracker& conn_tracker, + protocols::mqtt::Record record, DataTable* data_table) { + md::UPID upid(ctx->GetASID(), conn_tracker.conn_id().upid.pid, + conn_tracker.conn_id().upid.start_time_ticks); + + endpoint_role_t role = conn_tracker.role(); + DataTable::RecordBuilder<&kMQTTTable> r(data_table, record.resp.timestamp_ns); + r.Append(record.req.timestamp_ns); + r.Append(upid.value()); + r.Append(conn_tracker.remote_endpoint().AddrStr()); + r.Append(conn_tracker.remote_endpoint().port()); + r.Append(role); + r.Append(record.req.control_packet_type); + r.Append(protocols::mqtt::Message::MapToString(record.req.header_fields)); + r.Append(protocols::mqtt::Message::MapToString(record.req.properties)); + r.Append(protocols::mqtt::Message::MapToString(record.req.payload)); + r.Append(record.resp.control_packet_type); + r.Append(protocols::mqtt::Message::MapToString(record.resp.header_fields)); + r.Append(protocols::mqtt::Message::MapToString(record.resp.properties)); + r.Append(protocols::mqtt::Message::MapToString(record.resp.payload)); + r.Append( + CalculateLatency(record.req.timestamp_ns, record.resp.timestamp_ns)); +#ifndef NDEBUG + r.Append(PXInfoString(conn_tracker, record)); +#endif +} + void SocketTraceConnector::SetupOutput(const std::filesystem::path& path) { DCHECK(!path.empty()); diff --git a/src/stirling/source_connectors/socket_tracer/socket_trace_connector.h b/src/stirling/source_connectors/socket_tracer/socket_trace_connector.h index c010a1e8b2f..72e059cd911 100644 --- a/src/stirling/source_connectors/socket_tracer/socket_trace_connector.h +++ b/src/stirling/source_connectors/socket_tracer/socket_trace_connector.h @@ -65,6 +65,7 @@ DECLARE_int32(stirling_enable_nats_tracing); DECLARE_int32(stirling_enable_kafka_tracing); DECLARE_int32(stirling_enable_mux_tracing); DECLARE_int32(stirling_enable_amqp_tracing); +DECLARE_int32(stirling_enable_mqtt_tracing); DECLARE_bool(stirling_disable_self_tracing); DECLARE_string(stirling_role_to_trace); @@ -95,7 +96,7 @@ class SocketTraceConnector : public BCCSourceConnector { static constexpr std::string_view kName = "socket_tracer"; static constexpr auto kTables = MakeArray(kConnStatsTable, kHTTPTable, kMySQLTable, kCQLTable, kPGSQLTable, kDNSTable, - kRedisTable, kNATSTable, kKafkaTable, kMuxTable, kAMQPTable); + kRedisTable, kNATSTable, kKafkaTable, kMuxTable, kAMQPTable, kMQTTTable); static constexpr uint32_t kConnStatsTableNum = TableNum(kTables, kConnStatsTable); static constexpr uint32_t kHTTPTableNum = TableNum(kTables, kHTTPTable); @@ -108,6 +109,7 @@ class SocketTraceConnector : public BCCSourceConnector { static constexpr uint32_t kKafkaTableNum = TableNum(kTables, kKafkaTable); static constexpr uint32_t kMuxTableNum = TableNum(kTables, kMuxTable); static constexpr uint32_t kAMQPTableNum = TableNum(kTables, kAMQPTable); + static constexpr uint32_t kMQTTTableNum = TableNum(kTables, kMQTTTable); static constexpr auto kSamplingPeriod = std::chrono::milliseconds{200}; // TODO(yzhao): This is not used right now. Eventually use this to control data push frequency.