diff --git a/subsys/net/lib/mqtt/mqtt_decoder.c b/subsys/net/lib/mqtt/mqtt_decoder.c index 7469a762baca..23915afd9dd1 100644 --- a/subsys/net/lib/mqtt/mqtt_decoder.c +++ b/subsys/net/lib/mqtt/mqtt_decoder.c @@ -256,6 +256,13 @@ int publish_decode(u8_t flags, u32_t var_length, struct buf_ctx *buf, var_header_length += sizeof(u16_t); } + if (var_length < var_header_length) { + MQTT_ERR("Corrupted PUBLISH message, header length (%u) larger " + "than total length (%u)", var_header_length, + var_length); + return -EINVAL; + } + param->message.payload.data = NULL; param->message.payload.len = var_length - var_header_length; diff --git a/tests/net/lib/mqtt_packet/src/mqtt_packet.c b/tests/net/lib/mqtt_packet/src/mqtt_packet.c index fe57ebe5055c..1f3f9fd421a4 100644 --- a/tests/net/lib/mqtt_packet/src/mqtt_packet.c +++ b/tests/net/lib/mqtt_packet/src/mqtt_packet.c @@ -109,6 +109,15 @@ static int eval_msg_connect(struct mqtt_test *mqtt_test); */ static int eval_msg_publish(struct mqtt_test *mqtt_test); +/** + * @brief eval_msg_corrupted_publish Evaluate the given mqtt_test against the + * corrupted publish message. + * @param [in] mqtt_test MQTT test structure + * @return TC_PASS on success + * @return TC_FAIL on error + */ +static int eval_msg_corrupted_publish(struct mqtt_test *mqtt_test); + /** * @brief eval_msg_subscribe Evaluate the given mqtt_test against the * subscribe packing/unpacking routines. @@ -452,6 +461,14 @@ static ZTEST_DMEM struct mqtt_publish_param msg_publish4 = { .message.payload.len = 2, }; +static ZTEST_DMEM +u8_t publish_corrupted[] = {0x30, 0x07, 0x00, 0x07, 0x73, 0x65, 0x6e, 0x73, + 0x6f, 0x72, 0x73, 0x00, 0x01, 0x4f, 0x4b}; +static ZTEST_DMEM struct buf_ctx publish_corrupted_buf = { + .cur = publish_corrupted, + .end = publish_corrupted + sizeof(publish_corrupted) +}; + /* * MQTT SUBSCRIBE msg: * pkt_id: 1, topic: sensors, qos: 0 @@ -622,6 +639,9 @@ struct mqtt_test mqtt_tests[] = { .ctx = &msg_publish4, .eval_fcn = eval_msg_publish, .expected = publish4, .expected_len = sizeof(publish4)}, + {.test_name = "PUBLISH, corrupted message length (smaller than topic)", + .ctx = &publish_corrupted_buf, .eval_fcn = eval_msg_corrupted_publish}, + {.test_name = "SUBSCRIBE, one topic, qos = 0", .ctx = &msg_subscribe1, .eval_fcn = eval_msg_subscribe, .expected = subscribe1, .expected_len = sizeof(subscribe1)}, @@ -829,6 +849,23 @@ static int eval_msg_publish(struct mqtt_test *mqtt_test) return TC_PASS; } +static int eval_msg_corrupted_publish(struct mqtt_test *mqtt_test) +{ + struct buf_ctx *buf = (struct buf_ctx *)mqtt_test->ctx; + int rc; + u8_t type_and_flags; + u32_t length; + struct mqtt_publish_param dec_param; + + rc = fixed_header_decode(buf, &type_and_flags, &length); + zassert_equal(rc, 0, "fixed_header_decode failed"); + + rc = publish_decode(type_and_flags, length, buf, &dec_param); + zassert_equal(rc, -EINVAL, "publish_decode should fail"); + + return TC_PASS; +} + static int eval_msg_subscribe(struct mqtt_test *mqtt_test) { struct mqtt_subscription_list *param =