Skip to content

Commit

Permalink
Merge pull request #114 from Will-Low/master
Browse files Browse the repository at this point in the history
Make SQLSTATE optional in ServerError
  • Loading branch information
blackbeam authored Jan 15, 2024
2 parents 76bd713 + 8990ff2 commit d9d512f
Showing 1 changed file with 96 additions and 40 deletions.
136 changes: 96 additions & 40 deletions src/packets/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -838,11 +838,14 @@ impl<'de> MyDeserialize<'de> for ErrPacket<'de> {
sbuf.parse_unchecked::<ErrPacketHeader>(())?;
let code: RawInt<LeU16> = sbuf.parse_unchecked(())?;

// We assume that CLIENT_PROTOCOL_41 was set
if *code == 0xFFFF && capabilities.contains(CapabilityFlags::CLIENT_PROGRESS_OBSOLETE) {
buf.parse(()).map(ErrPacket::Progress)
} else {
buf.parse(*code).map(ErrPacket::Error)
buf.parse((
*code,
capabilities.contains(CapabilityFlags::CLIENT_PROTOCOL_41),
))
.map(ErrPacket::Error)
}
}
}
Expand Down Expand Up @@ -872,18 +875,70 @@ impl<'a> fmt::Display for ErrPacket<'a> {
}
}

define_header!(
SqlStateMarker,
InvalidSqlStateMarker("Invalid SqlStateMarker value"),
b'#'
);

/// MySql error state.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct SqlState {
__state_marker: SqlStateMarker,
state: [u8; 5],
}

impl SqlState {
/// Creates new sql state.
pub fn new(state: [u8; 5]) -> Self {
Self {
__state_marker: SqlStateMarker::new(),
state,
}
}

/// Returns an sql state as bytes.
pub fn as_bytes(&self) -> [u8; 5] {
self.state
}

/// Returns an sql state as a string (lossy converted).
pub fn as_str(&self) -> Cow<'_, str> {
String::from_utf8_lossy(&self.state)
}
}

impl<'de> MyDeserialize<'de> for SqlState {
const SIZE: Option<usize> = Some(6);
type Ctx = ();

fn deserialize((): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
Ok(Self {
__state_marker: buf.parse(())?,
state: buf.parse(())?,
})
}
}

impl MySerialize for SqlState {
fn serialize(&self, buf: &mut Vec<u8>) {
self.__state_marker.serialize(buf);
self.state.serialize(buf);
}
}

/// MySql error packet.
///
/// May hold an error or a progress report.
#[derive(Debug, Clone, PartialEq)]
pub struct ServerError<'a> {
code: RawInt<LeU16>,
state: [u8; 5],
state: Option<SqlState>,
message: RawBytes<'a, EofBytes>,
}

impl<'a> ServerError<'a> {
pub fn new(code: u16, state: [u8; 5], msg: impl Into<Cow<'a, [u8]>>) -> Self {
pub fn new(code: u16, state: Option<SqlState>, msg: impl Into<Cow<'a, [u8]>>) -> Self {
Self {
code: RawInt::new(code),
state,
Expand All @@ -897,13 +952,8 @@ impl<'a> ServerError<'a> {
}

/// Returns an sql state.
pub fn sql_state_ref(&self) -> [u8; 5] {
self.state
}

/// Returns an sql state as a string (lossy converted).
pub fn sql_state_str(&self) -> Cow<'_, str> {
String::from_utf8_lossy(&self.state[..])
pub fn sql_state_ref(&self) -> Option<&SqlState> {
self.state.as_ref()
}

/// Returns an error message.
Expand All @@ -927,43 +977,48 @@ impl<'a> ServerError<'a> {

impl<'de> MyDeserialize<'de> for ServerError<'de> {
const SIZE: Option<usize> = None;
/// An error packet error code.
type Ctx = u16;

fn deserialize(code: Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
match buf.0[0] {
b'#' => {
buf.skip(1);
Ok(ServerError {
code: RawInt::new(code),
state: buf.parse(())?,
message: buf.parse(())?,
})
/// An error packet error code + whether CLIENT_PROTOCOL_41 capability was negotiated.
type Ctx = (u16, bool);

fn deserialize((code, protocol_41): Self::Ctx, buf: &mut ParseBuf<'de>) -> io::Result<Self> {
let server_error = if protocol_41 {
ServerError {
code: RawInt::new(code),
state: Some(buf.parse(())?),
message: buf.parse(())?,
}
_ => Ok(ServerError {
} else {
ServerError {
code: RawInt::new(code),
state: *b"HY000",
state: None,
message: buf.parse(())?,
}),
}
}
};
Ok(server_error)
}
}

impl MySerialize for ServerError<'_> {
fn serialize(&self, buf: &mut Vec<u8>) {
buf.put_u8(b'#');
buf.put_slice(&self.state[..]);
if let Some(state) = &self.state {
state.serialize(buf);
}
self.message.serialize(buf);
}
}

impl fmt::Display for ServerError<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let sql_state_str = self
.sql_state_ref()
.map(|s| format!(" ({})", s.as_str()))
.unwrap_or_default();

write!(
f,
"ERROR {} ({}): {}",
"ERROR {}{}: {}",
self.error_code(),
self.sql_state_str(),
sql_state_str,
self.message_str()
)
}
Expand Down Expand Up @@ -3570,21 +3625,22 @@ mod test {
\x6f\x6e\x6e\x65\x63\x74\x69\x6f\x6e\x73";
const PROGRESS_PACKET: &[u8] = b"\xff\xff\xff\x01\x01\x0a\xcc\x5b\x00\x0astage name";

let err_packet =
ErrPacket::deserialize(CapabilityFlags::empty(), &mut ParseBuf(ERR_PACKET)).unwrap();
let err_packet = err_packet.server_error();
assert_eq!(err_packet.error_code(), 1096);
assert_eq!(err_packet.sql_state_str(), "HY000");
assert_eq!(err_packet.message_str(), "No tables used");

let err_packet = ErrPacket::deserialize(
CapabilityFlags::CLIENT_PROTOCOL_41,
&mut ParseBuf(ERR_PACKET_NO_STATE),
&mut ParseBuf(ERR_PACKET),
)
.unwrap();
let err_packet = err_packet.server_error();
assert_eq!(err_packet.error_code(), 1096);
assert_eq!(err_packet.sql_state_ref().unwrap().as_str(), "HY000");
assert_eq!(err_packet.message_str(), "No tables used");

let err_packet =
ErrPacket::deserialize(CapabilityFlags::empty(), &mut ParseBuf(ERR_PACKET_NO_STATE))
.unwrap();
let server_error = err_packet.server_error();
assert_eq!(server_error.error_code(), 1040);
assert_eq!(server_error.sql_state_str(), "HY000");
assert_eq!(server_error.sql_state_ref(), None);
assert_eq!(server_error.message_str(), "Too many connections");

let err_packet = ErrPacket::deserialize(
Expand Down

0 comments on commit d9d512f

Please sign in to comment.