From 45017d6e9530ccc71a44ced20c1332ca18a135eb Mon Sep 17 00:00:00 2001 From: taras Date: Fri, 13 Sep 2024 00:21:29 +0200 Subject: [PATCH] Fix exception on bytearray object passed to data_received callback --- picows/picows.pyx | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/picows/picows.pyx b/picows/picows.pyx index 051f6b2..198e5c6 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -136,6 +136,23 @@ cdef void _mask_payload(uint8_t* input, size_t input_len, uint32_t mask) noexcep input[i] ^= mask_buf[i] +cdef _unpack_bytes_like(object bytes_like_obj, char** msg_ptr_out, Py_ssize_t* msg_size_out): + cdef Py_buffer msg_buffer + + if PyBytes_CheckExact(bytes_like_obj): + msg_ptr_out[0] = PyBytes_AS_STRING(bytes_like_obj) + msg_size_out[0] = PyBytes_GET_SIZE(bytes_like_obj) + elif PyByteArray_CheckExact(bytes_like_obj): + msg_ptr_out[0] = PyByteArray_AS_STRING(bytes_like_obj) + msg_size_out[0] = PyByteArray_GET_SIZE(bytes_like_obj) + else: + PyObject_GetBuffer(bytes_like_obj, &msg_buffer, PyBUF_SIMPLE) + msg_ptr_out[0] = msg_buffer.buf + msg_size_out[0] = msg_buffer.len + # We can already release because we still keep the reference to the message + PyBuffer_Release(&msg_buffer) + + @cython.no_gc @cython.freelist(64) cdef class WSFrame: @@ -402,25 +419,14 @@ cdef class WSTransport: is compressed. """ cdef: - Py_buffer msg_buffer char* msg_ptr Py_ssize_t msg_length if message is None: msg_ptr = b"" msg_length = 0 - elif PyBytes_CheckExact(message): - msg_ptr = PyBytes_AS_STRING(message) - msg_length = PyBytes_GET_SIZE(message) - elif PyByteArray_CheckExact(message): - msg_ptr = PyByteArray_AS_STRING(message) - msg_length = PyByteArray_GET_SIZE(message) else: - PyObject_GetBuffer(message, &msg_buffer, PyBUF_SIMPLE) - msg_ptr = msg_buffer.buf - msg_length = msg_buffer.len - # We can already release because we still keep the reference to the message - PyBuffer_Release(&msg_buffer) + _unpack_bytes_like(message, &msg_ptr, &msg_length) cdef: uint8_t first_byte = msg_type @@ -761,10 +767,12 @@ cdef class WSProtocol: if self.listener is not None: self.listener.resume_writing() - def data_received(self, bytes data): + def data_received(self, data): cdef: - const char * ptr = PyBytes_AS_STRING(data) - size_t sz = PyBytes_GET_SIZE(data) + char* ptr + Py_ssize_t sz + + _unpack_bytes_like(data, &ptr, &sz) # Leave some space for simd parsers like simdjson, they require extra # space beyond normal data to make sure that vector reads