Skip to content

Commit

Permalink
Fix exception on bytearray object passed to data_received callback
Browse files Browse the repository at this point in the history
  • Loading branch information
taras committed Sep 12, 2024
1 parent 477a969 commit 45017d6
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions picows/picows.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,23 @@ cdef void _mask_payload(uint8_t* input, size_t input_len, uint32_t mask) noexcep
input[i] ^= mask_buf[i]


cdef _unpack_bytes_like(object bytes_like_obj, char** msg_ptr_out, Py_ssize_t* msg_size_out):
cdef Py_buffer msg_buffer

if PyBytes_CheckExact(bytes_like_obj):
msg_ptr_out[0] = PyBytes_AS_STRING(bytes_like_obj)
msg_size_out[0] = PyBytes_GET_SIZE(bytes_like_obj)
elif PyByteArray_CheckExact(bytes_like_obj):
msg_ptr_out[0] = PyByteArray_AS_STRING(bytes_like_obj)
msg_size_out[0] = PyByteArray_GET_SIZE(bytes_like_obj)
else:
PyObject_GetBuffer(bytes_like_obj, &msg_buffer, PyBUF_SIMPLE)
msg_ptr_out[0] = <char*>msg_buffer.buf
msg_size_out[0] = msg_buffer.len
# We can already release because we still keep the reference to the message
PyBuffer_Release(&msg_buffer)


@cython.no_gc
@cython.freelist(64)
cdef class WSFrame:
Expand Down Expand Up @@ -402,25 +419,14 @@ cdef class WSTransport:
is compressed.
"""
cdef:
Py_buffer msg_buffer
char* msg_ptr
Py_ssize_t msg_length

if message is None:
msg_ptr = b""
msg_length = 0
elif PyBytes_CheckExact(message):
msg_ptr = PyBytes_AS_STRING(message)
msg_length = PyBytes_GET_SIZE(message)
elif PyByteArray_CheckExact(message):
msg_ptr = PyByteArray_AS_STRING(message)
msg_length = PyByteArray_GET_SIZE(message)
else:
PyObject_GetBuffer(message, &msg_buffer, PyBUF_SIMPLE)
msg_ptr = <char*>msg_buffer.buf
msg_length = msg_buffer.len
# We can already release because we still keep the reference to the message
PyBuffer_Release(&msg_buffer)
_unpack_bytes_like(message, &msg_ptr, &msg_length)

cdef:
uint8_t first_byte = <uint8_t>msg_type
Expand Down Expand Up @@ -761,10 +767,12 @@ cdef class WSProtocol:
if self.listener is not None:
self.listener.resume_writing()

def data_received(self, bytes data):
def data_received(self, data):
cdef:
const char * ptr = PyBytes_AS_STRING(data)
size_t sz = PyBytes_GET_SIZE(data)
char* ptr
Py_ssize_t sz

_unpack_bytes_like(data, &ptr, &sz)

# Leave some space for simd parsers like simdjson, they require extra
# space beyond normal data to make sure that vector reads
Expand Down

0 comments on commit 45017d6

Please sign in to comment.