diff --git a/src/protocols/MQTT/MQTT.cpp b/src/protocols/MQTT/MQTT.cpp index cb62321c..fb4791f7 100644 --- a/src/protocols/MQTT/MQTT.cpp +++ b/src/protocols/MQTT/MQTT.cpp @@ -411,10 +411,11 @@ int16_t MQTTClient::check(void (*func)(const char*, const char*)) { _tl->receive(dataIn, numBytes); if(dataIn[0] == MQTT_PUBLISH << 4) { // TODO: properly decode remaining length - uint8_t remainingLength = dataIn[1]; + uint8_t remLenFieldLen = 0; + uint32_t remainingLength = decodeLength(dataIn + 1, remLenFieldLen); // get the topic - size_t topicLength = dataIn[3] | dataIn[2] << 8; + size_t topicLength = dataIn[remLenFieldLen + 2] | dataIn[remLenFieldLen + 1] << 8; char* topic = new char[topicLength + 1]; memcpy(topic, dataIn + 4, topicLength); topic[topicLength] = 0x00; @@ -422,7 +423,7 @@ int16_t MQTTClient::check(void (*func)(const char*, const char*)) { // get the message size_t messageLength = remainingLength - topicLength - 2; char* message = new char[messageLength + 1]; - memcpy(message, dataIn + 4 + topicLength, messageLength); + memcpy(message, dataIn + remLenFieldLen + 3 + topicLength, messageLength); message[messageLength] = 0x00; // execute the callback function provided by user @@ -452,7 +453,7 @@ size_t MQTTClient::encodeLength(uint32_t len, uint8_t* encoded) { return(i); } -uint32_t MQTTClient::decodeLength(uint8_t* encoded) { +uint32_t MQTTClient::decodeLength(uint8_t* encoded, uint8_t& numBytes) { // algorithm to decode packet length as per MQTT specification 3.1.1 uint32_t mult = 1; uint32_t len = 0; @@ -464,7 +465,8 @@ uint32_t MQTTClient::decodeLength(uint8_t* encoded) { // malformed remaining length return(0); } - } while((encoded[i] & 128) != 0); + } while((encoded[i++] & 128) != 0); + numBytes = i; return len; } diff --git a/src/protocols/MQTT/MQTT.h b/src/protocols/MQTT/MQTT.h index 87b8444d..6f21f585 100644 --- a/src/protocols/MQTT/MQTT.h +++ b/src/protocols/MQTT/MQTT.h @@ -139,7 +139,7 @@ class MQTTClient { uint16_t _packetId; static size_t encodeLength(uint32_t len, uint8_t* encoded); - static uint32_t decodeLength(uint8_t* encoded); + static uint32_t decodeLength(uint8_t* encoded, uint8_t& numBytes); }; #endif