Skip to content

Commit

Permalink
Merge pull request #4838 from d3matt/odd_nonblocking_receive_incorrec…
Browse files Browse the repository at this point in the history
…t_mask

fix(Net) bad mask with odd number of bytes
  • Loading branch information
obiltschnig authored Jan 8, 2025
2 parents bd7be38 + 0cc3ab4 commit 5652837
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
3 changes: 2 additions & 1 deletion Net/include/Poco/Net/WebSocketImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class Net_API WebSocketImpl: public StreamSocketImpl
int payloadLength = 0;
int remainingPayloadLength = 0;
Poco::Buffer<char> payload{0};
int maskOffset = 0;
};

struct SendState
Expand All @@ -133,7 +134,7 @@ class Net_API WebSocketImpl: public StreamSocketImpl

int peekHeader(ReceiveState& receiveState);
void skipHeader(int headerLength);
int receivePayload(char *buffer, int payloadLength, char mask[MASK_LENGTH], bool useMask);
int receivePayload(char *buffer, int payloadLength, char mask[MASK_LENGTH], bool useMask, int maskOffset);
int receiveNBytes(void* buffer, int length);
int receiveSomeBytes(char* buffer, int length);
int peekSomeBytes(char* buffer, int length);
Expand Down
16 changes: 10 additions & 6 deletions Net/src/WebSocketImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,14 @@ void WebSocketImpl::setMaxPayloadSize(int maxPayloadSize)
}


int WebSocketImpl::receivePayload(char *buffer, int payloadLength, char mask[MASK_LENGTH], bool useMask)
int WebSocketImpl::receivePayload(char *buffer, int payloadLength, char mask[MASK_LENGTH], bool useMask, int maskOffset)
{
int received = receiveNBytes(reinterpret_cast<char*>(buffer), payloadLength);
if (received > 0 && useMask)
{
for (int i = 0; i < received; i++)
{
buffer[i] ^= mask[i % MASK_LENGTH];
buffer[i] ^= mask[(i + maskOffset) % MASK_LENGTH];
}
}
return received;
Expand Down Expand Up @@ -297,7 +297,7 @@ int WebSocketImpl::receiveBytes(void* buffer, int length, int)

skipHeader(_receiveState.headerLength);

if (receivePayload(reinterpret_cast<char*>(buffer), payloadLength, _receiveState.mask, _receiveState.useMask) != payloadLength)
if (receivePayload(reinterpret_cast<char*>(buffer), payloadLength, _receiveState.mask, _receiveState.useMask, 0) != payloadLength)
throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME);

return payloadLength;
Expand Down Expand Up @@ -326,17 +326,19 @@ int WebSocketImpl::receiveBytes(void* buffer, int length, int)
throw WebSocketException(Poco::format("Insufficient buffer for payload size %d", _receiveState.payloadLength), WebSocket::WS_ERR_PAYLOAD_TOO_BIG);
}
int payloadOffset = _receiveState.payloadLength - _receiveState.remainingPayloadLength;
int n = receivePayload(_receiveState.payload.begin() + payloadOffset, _receiveState.remainingPayloadLength, _receiveState.mask, _receiveState.useMask);
int n = receivePayload(_receiveState.payload.begin() + payloadOffset, _receiveState.remainingPayloadLength, _receiveState.mask, _receiveState.useMask, _receiveState.maskOffset);
if (n > 0)
{
_receiveState.remainingPayloadLength -= n;
if (_receiveState.remainingPayloadLength == 0)
{
_receiveState.maskOffset = 0;
std::memcpy(buffer, _receiveState.payload.begin(), _receiveState.payloadLength);
return _receiveState.payloadLength;
}
else
{
_receiveState.maskOffset += n;
return -1;
}
}
Expand Down Expand Up @@ -369,7 +371,7 @@ int WebSocketImpl::receiveBytes(Poco::Buffer<char>& buffer, int, const Poco::Tim
std::size_t oldSize = buffer.size();
buffer.resize(oldSize + payloadLength);

if (receivePayload(buffer.begin() + oldSize, payloadLength, _receiveState.mask, _receiveState.useMask) != payloadLength)
if (receivePayload(buffer.begin() + oldSize, payloadLength, _receiveState.mask, _receiveState.useMask, 0) != payloadLength)
throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME);

return payloadLength;
Expand All @@ -387,12 +389,13 @@ int WebSocketImpl::receiveBytes(Poco::Buffer<char>& buffer, int, const Poco::Tim
_receiveState.payload.resize(payloadLength, false);
}
int payloadOffset = _receiveState.payloadLength - _receiveState.remainingPayloadLength;
int n = receivePayload(_receiveState.payload.begin() + payloadOffset, _receiveState.remainingPayloadLength, _receiveState.mask, _receiveState.useMask);
int n = receivePayload(_receiveState.payload.begin() + payloadOffset, _receiveState.remainingPayloadLength, _receiveState.mask, _receiveState.useMask, _receiveState.maskOffset);
if (n > 0)
{
_receiveState.remainingPayloadLength -= n;
if (_receiveState.remainingPayloadLength == 0)
{
_receiveState.maskOffset = 0;
std::size_t oldSize = buffer.size();
buffer.resize(oldSize + _receiveState.payloadLength);

Expand All @@ -401,6 +404,7 @@ int WebSocketImpl::receiveBytes(Poco::Buffer<char>& buffer, int, const Poco::Tim
}
else
{
_receiveState.maskOffset += n;
return -1;
}
}
Expand Down

0 comments on commit 5652837

Please sign in to comment.