Skip to content

Commit

Permalink
Introduce SqlState. Fix MyDeserialize impl for ErrPacket
Browse files Browse the repository at this point in the history
  • Loading branch information
blackbeam committed Jan 15, 2024
1 parent 98c794a commit 8afb0d4
Showing 1 changed file with 85 additions and 51 deletions.
136 changes: 85 additions & 51 deletions src/packets/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use regex::bytes::Regex;
use smallvec::SmallVec;
use uuid::Uuid;

use std::convert::TryInto;
use std::str::FromStr;
use std::{
borrow::Cow, cmp::max, collections::HashMap, convert::TryFrom, fmt, io, marker::PhantomData,
Expand Down Expand Up @@ -838,11 +837,14 @@ impl<'de> MyDeserialize<'de> for ErrPacket<'de> {
sbuf.parse_unchecked::<ErrPacketHeader>(())?;
let code: RawInt<LeU16> = sbuf.parse_unchecked(())?;

// We assume that CLIENT_PROTOCOL_41 was set
if *code == 0xFFFF && capabilities.contains(CapabilityFlags::CLIENT_PROGRESS_OBSOLETE) {
buf.parse(()).map(ErrPacket::Progress)
} else {
buf.parse(*code).map(ErrPacket::Error)
buf.parse((
*code,
capabilities.contains(CapabilityFlags::CLIENT_PROTOCOL_41),
))
.map(ErrPacket::Error)
}
}
}
Expand Down Expand Up @@ -872,18 +874,70 @@ impl<'a> fmt::Display for ErrPacket<'a> {
}
}

define_header!(
SqlStateMarker,
InvalidSqlStateMarker("Invalid SqlStateMarker value"),
b'#'
);

/// MySql error state.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct SqlState {
__state_marker: SqlStateMarker,
state: [u8; 5],
}

impl SqlState {
/// Creates new sql state.
pub fn new(state: [u8; 5]) -> Self {
Self {
__state_marker: SqlStateMarker::new(),
state,
}
}

/// Returns an sql state as bytes.
pub fn as_bytes(&self) -> [u8; 5] {
self.state
}

/// Returns an sql state as a string (lossy converted).
pub fn as_str(&self) -> Cow<'_, str> {
String::from_utf8_lossy(&self.state)
}
}

impl<'de> MyDeserialize<'de> for SqlState {
const SIZE: Option<usize> = Some(6);
type Ctx = ();

fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
Ok(Self {
__state_marker: buf.parse(())?,
state: buf.parse(())?,
})
}
}

impl MySerialize for SqlState {
fn serialize(&self, buf: &mut Vec<u8>) {
self.__state_marker.serialize(buf);
self.state.serialize(buf);
}
}

/// MySql error packet.
///
/// May hold an error or a progress report.
#[derive(Debug, Clone, PartialEq)]
pub struct ServerError<'a> {
code: RawInt<LeU16>,
state: Option<[u8; 5]>,
state: Option<SqlState>,
message: RawBytes<'a, EofBytes>,
}

impl<'a> ServerError<'a> {
pub fn new(code: u16, state: Option<[u8; 5]>, msg: impl Into<Cow<'a, [u8]>>) -> Self {
pub fn new(code: u16, state: Option<SqlState>, msg: impl Into<Cow<'a, [u8]>>) -> Self {
Self {
code: RawInt::new(code),
state,
Expand All @@ -897,13 +951,8 @@ impl<'a> ServerError<'a> {
}

/// Returns an sql state.
pub fn sql_state_ref(&self) -> Option<[u8; 5]> {
self.state
}

/// Returns an sql state as a string (lossy converted).
pub fn sql_state_str(&self) -> Option<Cow<'_, str>> {
self.state.as_ref().map(|s| String::from_utf8_lossy(s))
pub fn sql_state_ref(&self) -> Option<&SqlState> {
self.state.as_ref()
}

/// Returns an error message.
Expand All @@ -925,60 +974,45 @@ impl<'a> ServerError<'a> {
}
}

impl<'de> MyDeserialize<'de> for Option<[u8; 5]> {
const SIZE: Option<usize> = Some(5);
type Ctx = usize;

fn deserialize(len: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
match buf.checked_eat(len) {
Some(b) => Ok(Some(b.try_into().unwrap())),
None => Err(unexpected_buf_eof()),
}
}
}

const STATE_CODE_LEN: usize = 5;

impl<'de> MyDeserialize<'de> for ServerError<'de> {
const SIZE: Option<usize> = None;
/// An error packet error code.
type Ctx = u16;

fn deserialize(code: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
match buf.0[0] {
b'#' => {
buf.skip(1);
Ok(ServerError {
code: RawInt::new(code),
state: buf.parse(STATE_CODE_LEN)?,
message: buf.parse(())?,
})
/// An error packet error code + whether CLIENT_PROTOCOL_41 capability was negotiated.
type Ctx = (u16, bool);

fn deserialize((code, protocol_41): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
let server_error = if protocol_41 {
ServerError {
code: RawInt::new(code),
state: Some(buf.parse(())?),
message: buf.parse(())?,
}
_ => Ok(ServerError {
} else {
ServerError {
code: RawInt::new(code),
state: None,
message: buf.parse(())?,
}),
}
}
};
Ok(server_error)
}
}

impl MySerialize for ServerError<'_> {
fn serialize(&self, buf: &mut Vec<u8>) {
if let Some(state) = &self.state {
buf.put_u8(b'#');
buf.put_slice(state);
state.serialize(buf);
}
self.message.serialize(buf);
}
}

impl fmt::Display for ServerError<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let sql_state_str = match self.sql_state_str() {
Some(s) => format!(" ({})", s.as_ref()),
None => "".to_string(),
};
let sql_state_str = self
.sql_state_ref()
.map(|s| format!(" ({})", s.as_str()))
.unwrap_or_default();

write!(
f,
"ERROR {}{}: {}",
Expand Down Expand Up @@ -3543,20 +3577,20 @@ mod test {
const PROGRESS_PACKET: &[u8] = b"\xff\xff\xff\x01\x01\x0a\xcc\x5b\x00\x0astage name";

let err_packet =
ErrPacket::deserialize(CapabilityFlags::empty(), &mut ParseBuf(ERR_PACKET)).unwrap();
ErrPacket::deserialize(CapabilityFlags::CLIENT_PROTOCOL_41, &mut ParseBuf(ERR_PACKET)).unwrap();
let err_packet = err_packet.server_error();
assert_eq!(err_packet.error_code(), 1096);
assert_eq!(err_packet.sql_state_str().unwrap(), "HY000");
assert_eq!(err_packet.sql_state_ref().unwrap().as_str(), "HY000");
assert_eq!(err_packet.message_str(), "No tables used");

let err_packet = ErrPacket::deserialize(
CapabilityFlags::CLIENT_PROTOCOL_41,
CapabilityFlags::empty(),
&mut ParseBuf(ERR_PACKET_NO_STATE),
)
.unwrap();
let server_error = err_packet.server_error();
assert_eq!(server_error.error_code(), 1040);
assert_eq!(server_error.sql_state_str(), None);
assert_eq!(server_error.sql_state_ref(), None);
assert_eq!(server_error.message_str(), "Too many connections");

let err_packet = ErrPacket::deserialize(
Expand Down

0 comments on commit 8afb0d4

Please sign in to comment.