diff --git a/Cargo.lock b/Cargo.lock
index 75f8caf..0598ba9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -57,6 +57,7 @@ dependencies = [
  "anyhow",
  "bytes",
  "enum_dispatch",
+ "thiserror",
 ]
 
 [[package]]
@@ -70,6 +71,26 @@ dependencies = [
  "unicode-ident",
 ]
 
+[[package]]
+name = "thiserror"
+version = "1.0.58"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297"
+dependencies = [
+ "thiserror-impl",
+]
+
+[[package]]
+name = "thiserror-impl"
+version = "1.0.58"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "unicode-ident"
 version = "1.0.12"
diff --git a/Cargo.toml b/Cargo.toml
index 89dcb2e..e0d8dc0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,3 +9,4 @@ edition = "2021"
 anyhow = "1.0.81"
 bytes = "1.6.0"
 enum_dispatch = "0.3.13"
+thiserror = "1.0.58"
diff --git a/src/resp/decode.rs b/src/resp/decode.rs
index 4bd8276..465109c 100644
--- a/src/resp/decode.rs
+++ b/src/resp/decode.rs
@@ -1 +1,637 @@
-// empty file
+/*
+- 如何解析 Frame
+    - simple string: "+OK\r\n"
+    - error: "-Error message\r\n"
+    - bulk error: "!<length>\r\n<error>\r\n"
+    - integer: ":[<+|->]<value>\r\n"
+    - bulk string: "$<length>\r\n<data>\r\n"
+    - null bulk string: "$-1\r\n"
+    - array: "*<number-of-elements>\r\n<element-1>...<element-n>"
+        - "*2\r\n$3\r\nget\r\n$5\r\nhello\r\n"
+    - null array: "*-1\r\n"
+    - null: "_\r\n"
+    - boolean: "#<t|f>\r\n"
+    - double: ",[<+|->]<integral>[.<fractional>][<E|e>[sign]<exponent>]\r\n"
+    - map: "%<number-of-entries>\r\n<key-1><value-1>...<key-n><value-n>"
+    - set: "~<number-of-elements>\r\n<element-1>...<element-n>"
+ */
+
+use crate::{
+    BulkString, RespArray, RespDecode, RespError, RespFrame, RespMap, RespNull, RespNullArray,
+    RespNullBulkString, RespSet, SimpleError, SimpleString,
+};
+use bytes::{Buf, BytesMut};
+
+// Terminator that ends every RESP header line and every simple frame.
+const CRLF: &[u8] = b"\r\n";
+const CRLF_LEN: usize = CRLF.len();
+
+// Complete wire forms of the two legacy null frames.
+const NULL_BULK_STRING: &[u8] = b"$-1\r\n";
+const NULL_ARRAY: &[u8] = b"*-1\r\n";
+
+/// Returns `Ok(len)` when `buf` already holds at least `len` bytes, `NotComplete` otherwise.
+///
+/// Callers slice the buffer by the reported length, so a length must never be reported for
+/// bytes that have not arrived yet.
+fn require_len(buf: &[u8], len: usize) -> Result<usize, RespError> {
+    if buf.len() < len {
+        Err(RespError::NotComplete)
+    } else {
+        Ok(len)
+    }
+}
+
+/// Builds the error returned when a declared bulk-string length cannot be represented.
+fn bulk_length_overflow(len: usize) -> RespError {
+    RespError::InvalidFrame(format!("bulk string length too large: {}", len))
+}
+
+// Dispatcher: looks at the leading type byte and delegates to the concrete frame decoder.
+impl RespDecode for RespFrame {
+    const PREFIX: &'static str = "";
+
+    /// Decodes the frame at the front of `buf`, consuming its bytes on success.
+    ///
+    /// Returns `NotComplete` when `buf` is empty or the frame is only partially buffered, and
+    /// `InvalidFrameType` when the first byte is not a known RESP type marker.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        // Copy the type byte out so the shared borrow ends before `buf` is passed on mutably.
+        match buf.first().copied() {
+            // Nothing buffered yet: ask for more bytes rather than reporting a bad frame.
+            None => Err(RespError::NotComplete),
+            Some(b'+') => Ok(SimpleString::decode(buf)?.into()),
+            Some(b'-') => Ok(SimpleError::decode(buf)?.into()),
+            Some(b':') => Ok(i64::decode(buf)?.into()),
+            // "$-1\r\n" is the null bulk string; anything else is a regular bulk string.
+            Some(b'$') => match RespNullBulkString::decode(buf) {
+                Ok(frame) => Ok(frame.into()),
+                Err(RespError::NotComplete) => Err(RespError::NotComplete),
+                Err(_) => Ok(BulkString::decode(buf)?.into()),
+            },
+            // "*-1\r\n" is the null array; anything else is a regular array.
+            Some(b'*') => match RespNullArray::decode(buf) {
+                Ok(frame) => Ok(frame.into()),
+                Err(RespError::NotComplete) => Err(RespError::NotComplete),
+                Err(_) => Ok(RespArray::decode(buf)?.into()),
+            },
+            Some(b'_') => Ok(RespNull::decode(buf)?.into()),
+            Some(b'#') => Ok(bool::decode(buf)?.into()),
+            Some(b',') => Ok(f64::decode(buf)?.into()),
+            Some(b'%') => Ok(RespMap::decode(buf)?.into()),
+            Some(b'~') => Ok(RespSet::decode(buf)?.into()),
+            Some(_) => Err(RespError::InvalidFrameType(format!(
+                "decode: unknown frame type: {:?}",
+                buf
+            ))),
+        }
+    }
+
+    /// Returns the encoded length of the frame at the front of `buf` without consuming it.
+    ///
+    /// Null bulk strings and null arrays are recognised before the length-prefixed variants,
+    /// because their "-1" length cannot be parsed as an unsigned count. An unknown type byte
+    /// is an error rather than `NotComplete`: no amount of additional input can fix it.
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        match buf.first() {
+            None => Err(RespError::NotComplete),
+            Some(b'*') if buf.starts_with(NULL_ARRAY) => RespNullArray::expect_length(buf),
+            Some(b'*') => RespArray::expect_length(buf),
+            Some(b'~') => RespSet::expect_length(buf),
+            Some(b'%') => RespMap::expect_length(buf),
+            Some(b'$') if buf.starts_with(NULL_BULK_STRING) => {
+                RespNullBulkString::expect_length(buf)
+            }
+            Some(b'$') => BulkString::expect_length(buf),
+            Some(b':') => i64::expect_length(buf),
+            Some(b'+') => SimpleString::expect_length(buf),
+            Some(b'-') => SimpleError::expect_length(buf),
+            Some(b'#') => bool::expect_length(buf),
+            Some(b',') => f64::expect_length(buf),
+            Some(b'_') => RespNull::expect_length(buf),
+            Some(_) => Err(RespError::InvalidFrameType(format!(
+                "expect_length: unknown frame type: {:?}",
+                buf
+            ))),
+        }
+    }
+}
+
+// - simple string: "+OK\r\n"
+impl RespDecode for SimpleString {
+    const PREFIX: &'static str = "+";
+
+    /// Consumes one `+...\r\n` line; invalid UTF-8 is replaced lossily.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        let end = extract_simple_frame_data(buf, Self::PREFIX)?;
+        // Take the whole line (terminator included) off the front of the buffer.
+        let data = buf.split_to(end + CRLF_LEN);
+        let s = String::from_utf8_lossy(&data[Self::PREFIX.len()..end]);
+        Ok(SimpleString::new(s.to_string()))
+    }
+
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        let end = extract_simple_frame_data(buf, Self::PREFIX)?;
+        Ok(end + CRLF_LEN)
+    }
+}
+
+// - error: "-Error message\r\n"
+impl RespDecode for SimpleError {
+    const PREFIX: &'static str = "-";
+
+    /// Consumes one `-...\r\n` line; invalid UTF-8 is replaced lossily.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        let end = extract_simple_frame_data(buf, Self::PREFIX)?;
+        // Take the whole line (terminator included) off the front of the buffer.
+        let data = buf.split_to(end + CRLF_LEN);
+        let s = String::from_utf8_lossy(&data[Self::PREFIX.len()..end]);
+        Ok(SimpleError::new(s.to_string()))
+    }
+
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        let end = extract_simple_frame_data(buf, Self::PREFIX)?;
+        Ok(end + CRLF_LEN)
+    }
+}
+
+// - null: "_\r\n"
+impl RespDecode for RespNull {
+    const PREFIX: &'static str = "_";
+
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        extract_fixed_data(buf, "_\r\n", "Null")?;
+        Ok(RespNull)
+    }
+
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        require_len(buf, 3)
+    }
+}
+
+// - null array: "*-1\r\n"
+impl RespDecode for RespNullArray {
+    const PREFIX: &'static str = "*";
+
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        extract_fixed_data(buf, "*-1\r\n", "NullArray")?;
+        Ok(RespNullArray)
+    }
+
+    /// The null array is always the five bytes `*-1\r\n`.
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        require_len(buf, NULL_ARRAY.len())
+    }
+}
+
+// - null bulk string: "$-1\r\n"
+impl RespDecode for RespNullBulkString {
+    const PREFIX: &'static str = "$";
+
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        extract_fixed_data(buf, "$-1\r\n", "NullBulkString")?;
+        Ok(RespNullBulkString)
+    }
+
+    /// The null bulk string is always the five bytes `$-1\r\n`.
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        require_len(buf, NULL_BULK_STRING.len())
+    }
+}
+
+// - integer: ":[<+|->]<value>\r\n"
+impl RespDecode for i64 {
+    const PREFIX: &'static str = ":";
+
+    /// Consumes one `:...\r\n` line and parses its payload as a signed 64-bit integer.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        let end = extract_simple_frame_data(buf, Self::PREFIX)?;
+        // Take the whole line (terminator included) off the front of the buffer.
+        let data = buf.split_to(end + CRLF_LEN);
+        let s = String::from_utf8_lossy(&data[Self::PREFIX.len()..end]);
+        Ok(s.parse()?)
+    }
+
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        let end = extract_simple_frame_data(buf, Self::PREFIX)?;
+        Ok(end + CRLF_LEN)
+    }
+}
+
+// - boolean: "#<t|f>\r\n"
+impl RespDecode for bool {
+    const PREFIX: &'static str = "#";
+
+    /// Tries `#t\r\n` first and falls back to `#f\r\n`; a short buffer is `NotComplete`.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        match extract_fixed_data(buf, "#t\r\n", "Bool") {
+            Ok(()) => Ok(true),
+            Err(RespError::NotComplete) => Err(RespError::NotComplete),
+            Err(_) => extract_fixed_data(buf, "#f\r\n", "Bool").map(|()| false),
+        }
+    }
+
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        require_len(buf, 4)
+    }
+}
+
+// - bulk string: "$<length>\r\n<data>\r\n"
+impl RespDecode for BulkString {
+    const PREFIX: &'static str = "$";
+
+    /// Consumes the `$<length>\r\n` header plus `<length>` payload bytes and the trailing CRLF.
+    ///
+    /// Nothing is consumed when the payload has not fully arrived (`NotComplete`).
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        let (end, len) = parse_length(buf, Self::PREFIX)?;
+        let header_len = end + CRLF_LEN;
+        // The length comes from the peer, so guard the addition against overflow.
+        let payload_len = len
+            .checked_add(CRLF_LEN)
+            .ok_or_else(|| bulk_length_overflow(len))?;
+        // `header_len` never exceeds `buf.len()`: the header's CRLF was found inside `buf`.
+        if buf.len() - header_len < payload_len {
+            return Err(RespError::NotComplete);
+        }
+
+        buf.advance(header_len);
+
+        let data = buf.split_to(payload_len);
+        Ok(BulkString::new(data[..len].to_vec()))
+    }
+
+    /// Returns header + payload + CRLF length, or `NotComplete` while bytes are still missing.
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        let (end, len) = parse_length(buf, Self::PREFIX)?;
+        let total = len
+            .checked_add(end + 2 * CRLF_LEN)
+            .ok_or_else(|| bulk_length_overflow(len))?;
+        require_len(buf, total)
+    }
+}
+
+// - array: "*<number-of-elements>\r\n<element-1>...<element-n>"
+//     - "*2\r\n$3\r\nget\r\n$5\r\nhello\r\n"
+impl RespDecode for RespArray {
+    const PREFIX: &'static str = "*";
+
+    /// Decodes a `*<count>\r\n` header followed by `count` nested frames.
+    ///
+    /// The total encoded length is measured first, so nothing is consumed from `buf` until
+    /// the whole array has been received; a partial array yields `NotComplete`.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        let (end, len) = parse_length(buf, Self::PREFIX)?;
+        let total_len = calc_total_length(buf, end, len, Self::PREFIX)?;
+
+        if buf.len() < total_len {
+            return Err(RespError::NotComplete);
+        }
+
+        // Drop the header; the elements now sit at the front of the buffer.
+        buf.advance(end + CRLF_LEN);
+
+        let frames = (0..len)
+            .map(|_| RespFrame::decode(buf))
+            .collect::<Result<Vec<_>, _>>()?;
+
+        Ok(RespArray::new(frames))
+    }
+
+    /// Returns the encoded length of the header plus all elements.
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        let (end, len) = parse_length(buf, Self::PREFIX)?;
+        calc_total_length(buf, end, len, Self::PREFIX)
+    }
+}
+
+// - double: ",[<+|->]<integral>[.<fractional>][<E|e>[sign]<exponent>]\r\n"
+impl RespDecode for f64 {
+    const PREFIX: &'static str = ",";
+
+    /// Consumes one `,...\r\n` line and parses its payload as a 64-bit float.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        let end = extract_simple_frame_data(buf, Self::PREFIX)?;
+        let data = buf.split_to(end + CRLF_LEN);
+        let s = String::from_utf8_lossy(&data[Self::PREFIX.len()..end]);
+        Ok(s.parse()?)
+    }
+
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        let end = extract_simple_frame_data(buf, Self::PREFIX)?;
+        Ok(end + CRLF_LEN)
+    }
+}
+
+// - map: "%<number-of-entries>\r\n<key-1><value-1>...<key-n><value-n>"
+impl RespDecode for RespMap {
+    const PREFIX: &'static str = "%";
+
+    /// Decodes a `%<count>\r\n` header followed by `count` key/value pairs.
+    ///
+    /// Keys are decoded as simple strings; values may be any frame. As with arrays, the
+    /// buffer is only consumed once the whole map has been received.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        let (end, len) = parse_length(buf, Self::PREFIX)?;
+        let total_len = calc_total_length(buf, end, len, Self::PREFIX)?;
+
+        if buf.len() < total_len {
+            return Err(RespError::NotComplete);
+        }
+
+        // Drop the header; the entries now sit at the front of the buffer.
+        buf.advance(end + CRLF_LEN);
+
+        let mut frames = RespMap::new();
+        for _ in 0..len {
+            let key = SimpleString::decode(buf)?;
+            let value = RespFrame::decode(buf)?;
+            frames.insert(key.0, value);
+        }
+
+        Ok(frames)
+    }
+
+    /// Returns the encoded length of the header plus all key/value pairs.
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        let (end, len) = parse_length(buf, Self::PREFIX)?;
+        calc_total_length(buf, end, len, Self::PREFIX)
+    }
+}
+
+// - set: "~<number-of-elements>\r\n<element-1>...<element-n>"
+impl RespDecode for RespSet {
+    const PREFIX: &'static str = "~";
+
+    /// Decodes a `~<count>\r\n` header followed by `count` nested frames.
+    ///
+    /// The buffer is only consumed once the whole set has been received.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        let (end, len) = parse_length(buf, Self::PREFIX)?;
+        let total_len = calc_total_length(buf, end, len, Self::PREFIX)?;
+
+        if buf.len() < total_len {
+            return Err(RespError::NotComplete);
+        }
+
+        // Drop the header; the elements now sit at the front of the buffer.
+        buf.advance(end + CRLF_LEN);
+
+        // Every element was measured above, so `len` is bounded by the buffered bytes.
+        let mut frames = Vec::with_capacity(len);
+        for _ in 0..len {
+            frames.push(RespFrame::decode(buf)?);
+        }
+
+        Ok(RespSet::new(frames))
+    }
+
+    /// Returns the encoded length of the header plus all elements.
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        let (end, len) = parse_length(buf, Self::PREFIX)?;
+        calc_total_length(buf, end, len, Self::PREFIX)
+    }
+}
+
+/// Consumes the exact byte sequence `expect` from the front of `buf`.
+///
+/// Returns `NotComplete` when fewer than `expect.len()` bytes are buffered and
+/// `InvalidFrameType` (labelled with `expect_type`) when the bytes differ; `buf` is left
+/// untouched in both cases.
+fn extract_fixed_data(
+    buf: &mut BytesMut,
+    expect: &str,
+    expect_type: &str,
+) -> Result<(), RespError> {
+    if buf.len() < expect.len() {
+        return Err(RespError::NotComplete);
+    }
+
+    if !buf.starts_with(expect.as_bytes()) {
+        return Err(RespError::InvalidFrameType(format!(
+            "expect: {}, got: {:?}",
+            expect_type, buf
+        )));
+    }
+
+    buf.advance(expect.len());
+    Ok(())
+}
+
+/// Checks that `buf` starts with `prefix` and returns the index of the `\r` that ends the
+/// first line. Returns `NotComplete` while the line terminator has not arrived yet.
+fn extract_simple_frame_data(buf: &[u8], prefix: &str) -> Result<usize, RespError> {
+    // The shortest possible line is one type byte followed by CRLF.
+    if buf.len() < 3 {
+        return Err(RespError::NotComplete);
+    }
+
+    if !buf.starts_with(prefix.as_bytes()) {
+        return Err(RespError::InvalidFrameType(format!(
+            "expect: SimpleString({}), got: {:?}",
+            prefix, buf
+        )));
+    }
+
+    find_crlf(buf, 1).ok_or(RespError::NotComplete)
+}
+
+/// Returns the index of the `\r` of the `nth` (1-based) CRLF in `buf`.
+///
+/// Index 0 is never reported because it holds the frame's type byte. Empty and one-byte
+/// buffers, as well as `nth == 0`, yield `None`.
+fn find_crlf(buf: &[u8], nth: usize) -> Option<usize> {
+    let skip = nth.checked_sub(1)?;
+    buf.windows(CRLF_LEN)
+        .enumerate()
+        .skip(1)
+        .filter(|(_, pair)| *pair == CRLF)
+        .nth(skip)
+        .map(|(idx, _)| idx)
+}
+
+/// Parses the `<prefix><count>\r\n` header at the front of `buf`.
+///
+/// Returns the index of the header's `\r` together with the parsed count.
+fn parse_length(buf: &[u8], prefix: &str) -> Result<(usize, usize), RespError> {
+    let end = extract_simple_frame_data(buf, prefix)?;
+    let s = String::from_utf8_lossy(&buf[prefix.len()..end]);
+    Ok((end, s.parse()?))
+}
+
+/// Steps past a frame of `frame_len` bytes, reporting `NotComplete` when those bytes have
+/// not all arrived yet (instead of panicking on an out-of-range slice).
+fn skip_frame(data: &[u8], frame_len: usize) -> Result<&[u8], RespError> {
+    data.get(frame_len..).ok_or(RespError::NotComplete)
+}
+
+/// Computes the full encoded length of the aggregate frame at the front of `buf`.
+///
+/// `end` is the index of the header's `\r`, `len` the element count parsed from the header
+/// and `prefix` the aggregate's type marker: `*` (array) and `~` (set) hold `len` frames,
+/// `%` (map) holds `len` simple-string keys each followed by a value frame. Any other
+/// prefix falls back to `len + CRLF_LEN`.
+fn calc_total_length(buf: &[u8], end: usize, len: usize, prefix: &str) -> Result<usize, RespError> {
+    let mut total = end + CRLF_LEN;
+    let mut data = skip_frame(buf, total)?;
+    match prefix {
+        "*" | "~" => {
+            // One nested frame per element.
+            for _ in 0..len {
+                let frame_len = RespFrame::expect_length(data)?;
+                data = skip_frame(data, frame_len)?;
+                total += frame_len;
+            }
+            Ok(total)
+        }
+        "%" => {
+            // Two nested frames per entry: a simple-string key, then the value.
+            for _ in 0..len {
+                let key_len = SimpleString::expect_length(data)?;
+                data = skip_frame(data, key_len)?;
+                total += key_len;
+
+                let value_len = RespFrame::expect_length(data)?;
+                data = skip_frame(data, value_len)?;
+                total += value_len;
+            }
+            Ok(total)
+        }
+        _ => Ok(len + CRLF_LEN),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use anyhow::Result;
+    use bytes::BufMut;
+
+    #[test]
+    fn test_simple_string_decode() -> Result<()> {
+        let mut buf = BytesMut::new();
+        buf.extend_from_slice(b"+OK\r\n");
+
+        let frame = SimpleString::decode(&mut buf)?;
+        assert_eq!(frame, SimpleString::new("OK".to_string()));
+
+        buf.extend_from_slice(b"+hello\r");
+
+        let ret = SimpleString::decode(&mut buf);
+        assert_eq!(ret.unwrap_err(), RespError::NotComplete);
+
+        buf.put_u8(b'\n');
+        let frame = SimpleString::decode(&mut buf)?;
+        assert_eq!(frame, SimpleString::new("hello".to_string()));
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_simple_error_decode() -> Result<()> {
+        let mut buf = BytesMut::new();
+        buf.extend_from_slice(b"-Error message\r\n");
+
+        let frame = SimpleError::decode(&mut buf)?;
+        assert_eq!(frame, SimpleError::new("Error message".to_string()));
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_integer_decode() -> Result<()> {
+        let mut buf = BytesMut::new();
+        buf.extend_from_slice(b":+123\r\n");
+
+        let frame = i64::decode(&mut buf)?;
+        assert_eq!(frame, 123);
+
+        buf.extend_from_slice(b":-123\r\n");
+
+        let frame = i64::decode(&mut buf)?;
+        assert_eq!(frame, -123);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_bulk_string_decode() -> Result<()> {
+        let mut buf = BytesMut::new();
+        buf.extend_from_slice(b"$5\r\nhello\r\n");
+
+        let frame = BulkString::decode(&mut buf)?;
+        assert_eq!(frame, BulkString::new(b"hello"));
+
+        buf.extend_from_slice(b"$5\r\nhello");
+        let ret = BulkString::decode(&mut buf);
+        assert_eq!(ret.unwrap_err(), RespError::NotComplete);
+
+        buf.extend_from_slice(b"\r\n");
+        let frame = BulkString::decode(&mut buf)?;
+        assert_eq!(frame, BulkString::new(b"hello"));
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_null_bulk_string_decode() -> Result<()> {
+        let mut buf = BytesMut::new();
+        buf.extend_from_slice(b"$-1\r\n");
+
+        let frame = RespNullBulkString::decode(&mut buf)?;
+        assert_eq!(frame, RespNullBulkString);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_null_array_decode() -> Result<()> {
+        let mut buf = BytesMut::new();
+        buf.extend_from_slice(b"*-1\r\n");
+
+        let frame = RespNullArray::decode(&mut buf)?;
+        assert_eq!(frame, RespNullArray);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_null_decode() -> Result<()> {
+        let mut buf = BytesMut::new();
+        buf.extend_from_slice(b"_\r\n");
+
+        let frame = RespNull::decode(&mut buf)?;
+        assert_eq!(frame, RespNull);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_boolean_decode() -> Result<()> {
+        let mut buf = BytesMut::new();
+        buf.extend_from_slice(b"#t\r\n");
+
+        let frame = bool::decode(&mut buf)?;
+        assert!(frame);
+
+        buf.extend_from_slice(b"#f\r\n");
+
+        let frame = bool::decode(&mut buf)?;
+        assert!(!frame);
+
+        buf.extend_from_slice(b"#f\r");
+        let ret = bool::decode(&mut buf);
+        assert_eq!(ret.unwrap_err(), RespError::NotComplete);
+
+        buf.put_u8(b'\n');
+        let frame = bool::decode(&mut buf)?;
+        assert!(!frame);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_array_decode() -> Result<()> {
+        let mut buf = BytesMut::new();
+        buf.extend_from_slice(b"*2\r\n$3\r\nset\r\n$5\r\nhello\r\n");
+
+        let frame = RespArray::decode(&mut buf)?;
+        assert_eq!(frame, RespArray::new([b"set".into(), b"hello".into()]));
+
+        buf.extend_from_slice(b"*2\r\n$3\r\nset\r\n");
+        let ret = RespArray::decode(&mut buf);
+        assert_eq!(ret.unwrap_err(), RespError::NotComplete);
+
+        buf.extend_from_slice(b"$5\r\nhello\r\n");
+        let frame = RespArray::decode(&mut buf)?;
+        assert_eq!(frame, RespArray::new([b"set".into(), b"hello".into()]));
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_array_decode_truncated_element() {
+        // A bulk string whose payload is cut short must not panic while being measured.
+        let mut buf = BytesMut::new();
+        buf.extend_from_slice(b"*1\r\n$5\r\nhel");
+        let ret = RespArray::decode(&mut buf);
+        assert_eq!(ret.unwrap_err(), RespError::NotComplete);
+
+        // Same for a fixed-size element that has only partially arrived.
+        let mut buf = BytesMut::new();
+        buf.extend_from_slice(b"*1\r\n#t");
+        let ret = RespArray::decode(&mut buf);
+        assert_eq!(ret.unwrap_err(), RespError::NotComplete);
+    }
+
+    #[test]
+    fn test_double_decode() -> Result<()> {
+        let mut buf = BytesMut::new();
+        buf.extend_from_slice(b",123.45\r\n");
+
+        let frame = f64::decode(&mut buf)?;
+        assert_eq!(frame, 123.45);
+
+        buf.extend_from_slice(b",+1.23456e-9\r\n");
+        let frame = f64::decode(&mut buf)?;
+        assert_eq!(frame, 1.23456e-9);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_map_decode() -> Result<()> {
+        let mut buf = BytesMut::new();
+        buf.extend_from_slice(b"%2\r\n+hello\r\n$5\r\nworld\r\n+foo\r\n$3\r\nbar\r\n");
+
+        let frame = RespMap::decode(&mut buf)?;
+        let mut map = RespMap::new();
+        map.insert(
+            "hello".to_string(),
+            BulkString::new(b"world".to_vec()).into(),
+        );
+        map.insert("foo".to_string(), BulkString::new(b"bar".to_vec()).into());
+        assert_eq!(frame, map);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_set_decode() -> Result<()> {
+        let mut buf = BytesMut::new();
+        buf.extend_from_slice(b"~2\r\n$3\r\nset\r\n$5\r\nhello\r\n");
+
+        let frame = RespSet::decode(&mut buf)?;
+        assert_eq!(
+            frame,
+            RespSet::new(vec![
+                BulkString::new(b"set".to_vec()).into(),
+                BulkString::new(b"hello".to_vec()).into()
+            ])
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_calc_array_length() -> Result<()> {
+        let buf = b"*2\r\n$3\r\nset\r\n$5\r\nhello\r\n";
+        let (end, len) = parse_length(buf, "*")?;
+        let total_len = calc_total_length(buf, end, len, "*")?;
+        assert_eq!(total_len, buf.len());
+
+        let buf = b"*2\r\n$3\r\nset\r\n";
+        let (end, len) = parse_length(buf, "*")?;
+        let ret = calc_total_length(buf, end, len, "*");
+        assert_eq!(ret.unwrap_err(), RespError::NotComplete);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_find_crlf() {
+        assert_eq!(find_crlf(b"", 1), None);
+        assert_eq!(find_crlf(b"+", 1), None);
+        assert_eq!(find_crlf(b"+OK\r\n", 0), None);
+        assert_eq!(find_crlf(b"+OK\r\n", 1), Some(3));
+        assert_eq!(find_crlf(b"a\r\nb\r\n", 2), Some(4));
+    }
+}
diff --git a/src/resp/encode.rs b/src/resp/encode.rs
index 2a291b7..1faa54b 100644
--- a/src/resp/encode.rs
+++ b/src/resp/encode.rs
@@ -1,21 +1,3 @@
-/*
-- 如何解析 Frame
-    - simple string: "+OK\r\n"
-    - error: "-Error message\r\n"
-    - bulk error: "!<length>\r\n<error>\r\n"
-    - integer: ":[<+|->]<value>\r\n"
-    - bulk
string: "$<length>\r\n<data>\r\n"
-    - null bulk string: "$-1\r\n"
-    - array: "*<number-of-elements>\r\n<element-1>...<element-n>"
-        - "*2\r\n$3\r\nget\r\n$5\r\nhello\r\n"
-    - null array: "*-1\r\n"
-    - null: "_\r\n"
-    - boolean: "#<t|f>\r\n"
-    - double: ",[<+|->]<integral>[.<fractional>][<E|e>[sign]<exponent>]\r\n"
-    - map: "%<number-of-entries>\r\n<key-1><value-1>...<key-n><value-n>"
-    - set: "~<number-of-elements>\r\n<element-1>...<element-n>"
- */
-
 use crate::{
     BulkString, RespArray, RespEncode, RespMap, RespNull, RespNullArray, RespNullBulkString,
     RespSet, SimpleError, SimpleString,
diff --git a/src/resp/mod.rs b/src/resp/mod.rs
index 345c595..8dbbc14 100644
--- a/src/resp/mod.rs
+++ b/src/resp/mod.rs
@@ -1,17 +1,47 @@
 mod decode;
 mod encode;
 
+use bytes::BytesMut;
 use enum_dispatch::enum_dispatch;
 use std::collections::BTreeMap;
 use std::ops::{Deref, DerefMut};
+use thiserror::Error;
 
 #[enum_dispatch]
 pub trait RespEncode {
     fn encode(self) -> Vec<u8>;
 }
 
-pub trait RespDecode {
-    fn decode(buf: Self) -> Result<RespFrame, String>;
-}
+/// Decoding half of the RESP codec: parses one frame from the front of a byte buffer.
+pub trait RespDecode: Sized {
+    /// Type marker that introduces this frame on the wire (empty for the `RespFrame` dispatcher).
+    const PREFIX: &'static str;
+    /// Decodes one frame from the front of `buf`, consuming its bytes on success.
+    ///
+    /// Returns `RespError::NotComplete` when more bytes are needed.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError>;
+    /// Returns the encoded length of the frame at the front of `buf` without consuming it.
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError>;
+}
+
+/// Errors produced while decoding RESP frames.
+#[derive(Error, Debug, PartialEq, Eq)]
+pub enum RespError {
+    #[error("Invalid frame: {0}")]
+    InvalidFrame(String),
+    #[error("Invalid frame type: {0}")]
+    InvalidFrameType(String),
+    #[error("Invalid frame length: {0}")]
+    InvalidFrameLength(isize),
+    #[error("Frame is not complete")]
+    NotComplete,
+
+    #[error("Parse error: {0}")]
+    ParseIntError(#[from] std::num::ParseIntError),
+    #[error("Utf8 error: {0}")]
+    Utf8Error(#[from] std::str::Utf8Error),
+    #[error("Parse float error: {0}")]
+    ParseFloatError(#[from] std::num::ParseFloatError),
 }
 
 #[enum_dispatch(RespEncode)]
@@ -146,3 +176,51 @@ impl RespSet {
         RespSet(s.into())
     }
 }
+
+impl From<&str> for SimpleString {
+    fn from(s: &str) -> Self {
+        SimpleString(s.to_string())
+    }
+}
+
+impl From<&str> for RespFrame {
+    fn from(s: &str) -> Self {
+        SimpleString(s.to_string()).into()
+    }
+}
+
+impl From<&str> for SimpleError {
+    fn from(s: &str) -> Self {
+        SimpleError(s.to_string())
+    }
+}
+
+impl From<&str> for BulkString {
+    fn from(s: &str) -> Self {
+        BulkString(s.as_bytes().to_vec())
+    }
+}
+
+impl From<&[u8]> for BulkString {
+    fn from(s: &[u8]) -> Self {
+        BulkString(s.to_vec())
+    }
+}
+
+impl From<&[u8]> for RespFrame {
+    fn from(s: &[u8]) -> Self {
+        BulkString(s.to_vec()).into()
+    }
+}
+
+impl<const N: usize> From<&[u8; N]> for BulkString {
+    fn from(s: &[u8; N]) -> Self {
+        BulkString(s.to_vec())
+    }
+}
+
+impl<const N: usize> From<&[u8; N]> for RespFrame {
+    fn from(s: &[u8; N]) -> Self {
+        BulkString(s.to_vec()).into()
+    }
+}