diff --git a/Cargo.lock b/Cargo.lock index 0f2d3a5..870ed01 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -174,6 +174,7 @@ dependencies = [ "async-trait", "clap", "fern", + "glob", "log", "tokio", ] @@ -193,6 +194,12 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "heck" version = "0.4.1" diff --git a/Cargo.toml b/Cargo.toml index fc78402..b17fca2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ fern = "0.6" clap = { version = "4.1", features = ["derive"] } tokio = { version = "1.36", features = ["full", "tracing"] } async-trait = "0.1" +glob = "0.3.1" [profile.release] strip = true diff --git a/examples/insert_keys.rs b/examples/insert_keys.rs new file mode 100644 index 0000000..2de3bd0 --- /dev/null +++ b/examples/insert_keys.rs @@ -0,0 +1,47 @@ +use crabdis::error::Result; +use crabdis::storage::value::Value; +use crabdis::CLI; +use tokio::io::{AsyncWriteExt, BufReader}; +use tokio::net::TcpStream; + +#[tokio::main] +async fn main() -> Result<()> { + let cli = CLI { + address: [127, 0, 0, 1].into(), + port: 6379, + threads: 1, + }; + + let connect_address = format!("{}:{}", cli.address, cli.port); + + let mut stream = TcpStream::connect(connect_address).await?; + let (mut reader, mut writer) = stream.split(); + let mut bufreader = BufReader::new(&mut reader); + + for i in 0..1000 { + let req = Value::Multi( + vec![ + Value::String("SET".into()), + Value::String(format!("key{i}")), + Value::String(format!("value{i}")), + ] + .into(), + ); + + println!("Sending request: {req:?}"); + + req.to_resp(&mut writer).await?; + + writer.flush().await?; + + let Some(resp) = Value::from_resp(&mut bufreader).await? else { + return Ok(()); + }; + + println!("Received response: {resp:?}"); + + assert_eq!(resp, Value::Ok); + } + + Ok(()) +} diff --git a/src/commands/core/keys.rs b/src/commands/core/keys.rs new file mode 100644 index 0000000..da4ae91 --- /dev/null +++ b/src/commands/core/keys.rs @@ -0,0 +1,41 @@ +use glob::Pattern; + +use crate::prelude::*; + +pub struct Keys; + +#[async_trait] +impl CommandTrait for Keys { + fn name(&self) -> &str { + "KEYS" + } + + async fn handle_command( + &self, + writer: &mut WriteHalf, + args: &mut VecDeque, + context: ContextRef, + ) -> Result<()> { + if args.len() != 1 { + return value_error!("Invalid number of arguments") + .to_resp(writer) + .await; + } + + let pattern = match args.pop_front() { + Some(Value::String(s)) => Pattern::new(&s)?, + _ => { + return value_error!("Invalid pattern").to_resp(writer).await; + } + }; + + let mut keys = VecDeque::new(); + for key in context.store.read().await.keys() { + if pattern.matches(key) { + keys.push_back(Value::String(key.clone())); + } + } + + Value::Multi(keys).to_resp(writer).await + } +} diff --git a/src/commands/core/mod.rs b/src/commands/core/mod.rs index e1a58ea..da4d8cb 100644 --- a/src/commands/core/mod.rs +++ b/src/commands/core/mod.rs @@ -2,6 +2,7 @@ mod del; mod exists; mod flushdb; mod get; +mod keys; mod mget; mod mset; mod ping; @@ -11,6 +12,7 @@ pub use del::Del; pub use exists::Exists; pub use flushdb::FlushDB; pub use get::Get; +pub use keys::Keys; pub use mget::MGet; pub use mset::MSet; pub use ping::Ping; diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 065b3b6..e8221a8 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -18,6 +18,14 @@ pub trait CommandTrait { ) -> Result<()>; } +macro_rules! register_commands { + ($handler:expr, $($command:expr),+ $(,)?) => { + $( + $handler.register_command($command).await; + )+ + }; +} + #[derive(Clone, Default)] pub struct CommandHandler { commands: Arc>>>, @@ -25,17 +33,21 @@ pub struct CommandHandler { impl CommandHandler { pub async fn register(&mut self) { - self.register_command(core::Get).await; - self.register_command(core::Set).await; - self.register_command(core::Del).await; - self.register_command(core::Ping).await; - self.register_command(core::MGet).await; - self.register_command(core::MSet).await; - self.register_command(core::Exists).await; - self.register_command(core::FlushDB).await; + register_commands!( + self, + core::Get, + core::Set, + core::Del, + core::MGet, + core::Ping, + core::MSet, + core::Keys, + core::Exists, + core::FlushDB, + ); } - pub async fn register_command(&mut self, command: C) + async fn register_command(&mut self, command: C) where C: CommandTrait + Send + Sync + 'static, { diff --git a/src/error.rs b/src/error.rs index 3f70ce5..c364c32 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,11 +2,14 @@ use std::error::Error as StdError; use std::fmt::{self, Display}; use std::io::Error as IoError; +use glob::PatternError; + pub type Result = std::result::Result; #[derive(Debug)] pub enum Error { Io(IoError), + Glob(PatternError), } impl From for Error { @@ -15,10 +18,17 @@ impl From for Error { } } +impl From for Error { + fn from(e: PatternError) -> Self { + Self::Glob(e) + } +} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Io(inner) => fmt::Display::fmt(&inner, f), + Self::Glob(inner) => fmt::Display::fmt(&inner, f), } } } @@ -27,6 +37,7 @@ impl StdError for Error { fn source(&self) -> Option<&(dyn StdError + 'static)> { match self { Self::Io(inner) => Some(inner), + Self::Glob(inner) => Some(inner), } } } diff --git a/src/handler.rs b/src/handler.rs index d062c1e..c518722 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -14,12 +14,18 @@ pub async fn handle_client(stream: &mut tokio::net::TcpStream, context: ContextR log::debug!("Received request: {request:?}"); match request { - Value::Multi(mut args) => { + Some(Value::Multi(mut args)) => { context .commands .handle_command(&mut writer, &mut args, context.clone()) .await? } + + // If the request is None, the client has disconnected. + None => { + return Ok(()); + } + _ => { value_error!("Invalid request").to_resp(&mut writer).await?; } diff --git a/src/lib.rs b/src/lib.rs index afa83ad..dd66523 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,7 @@ mod context; pub mod error; mod handler; mod prelude; -mod storage; +pub mod storage; mod utils; use std::net::{IpAddr, SocketAddr}; diff --git a/src/storage/value.rs b/src/storage/value.rs index 55ff91f..a056179 100644 --- a/src/storage/value.rs +++ b/src/storage/value.rs @@ -106,7 +106,7 @@ impl Value { pub fn from_resp<'a, T>( reader: &'a mut BufReader<&mut T>, - ) -> Pin> + Send + 'a>> + ) -> Pin>> + Send + 'a>> where T: AsyncReadExt + Unpin + Send, { @@ -116,7 +116,7 @@ impl Value { reader.read_line(&mut line).await?; match line.chars().next() { - Some('$') if line == "$-1\r\n" => Ok(Self::Nil), + Some('$') if line == "$-1\r\n" => Ok(Some(Self::Nil)), Some('$') => { let len: usize = line[1..] @@ -130,9 +130,9 @@ impl Value { [value.pop(), value.pop()]; // remove \n\r - Ok(Self::String( + Ok(Some(Self::String( String::from_utf8(value).context("Could not parse string")?, - )) + ))) } Some(':') => { @@ -141,7 +141,7 @@ impl Value { .parse() .context("Could not parse integer")?; - Ok(Self::Integer(value)) + Ok(Some(Self::Integer(value))) } Some('*') => { @@ -152,13 +152,33 @@ impl Value { let mut values = VecDeque::with_capacity(len); for _ in 0..len { - values.push_back(Self::from_resp(reader).await?); + let value = Self::from_resp(reader).await?; + + if let Some(value) = value { + values.push_back(value); + } else { + return Ok(None); + } } - Ok(Self::Multi(values)) + Ok(Some(Self::Multi(values))) } - _ => Ok(Self::Error("Invalid response".to_string())), + Some('+') => { + let value = line[1..].trim(); + + match value { + "OK" => Ok(Some(Self::Ok)), + "PONG" => Ok(Some(Self::Pong)), + _ => unreachable!("Invalid response"), + } + } + + Some('-') => Ok(Some(Self::Error(line[1..].trim().to_string()))), + + None => Ok(None), + + _ => Ok(Some(Self::Error("Invalid response".to_string()))), } }) } @@ -257,21 +277,21 @@ mod tests { let mut reader = BufReader::new(&mut read); let value = Value::from_resp(&mut reader).await.unwrap(); - assert_eq!(value, Value::String("Hello, World!".to_string())); + assert_eq!(value, Some(Value::String("Hello, World!".to_string()))); let mut stream = create_tcp_stream(":42\r\n").await; let (mut read, _) = stream.split(); let mut reader = BufReader::new(&mut read); let value = Value::from_resp(&mut reader).await.unwrap(); - assert_eq!(value, Value::Integer(42)); + assert_eq!(value, Some(Value::Integer(42))); let mut stream = create_tcp_stream("$-1\r\n").await; let (mut read, _) = stream.split(); let mut reader = BufReader::new(&mut read); let value = Value::from_resp(&mut reader).await.unwrap(); - assert_eq!(value, Value::Nil); + assert_eq!(value, Some(Value::Nil)); let mut stream = create_tcp_stream("*3\r\n$13\r\nHello, World!\r\n:42\r\n$-1\r\n").await; let (mut read, _) = stream.split(); @@ -280,11 +300,11 @@ mod tests { let value = Value::from_resp(&mut reader).await.unwrap(); assert_eq!( value, - Value::Multi(VecDeque::from([ + Some(Value::Multi(VecDeque::from([ Value::String("Hello, World!".to_string()), Value::Integer(42), Value::Nil - ])) + ]))) ); let mut stream = create_tcp_stream("*2\r\n$3\r\nkey\r\n$5\r\nvalue\r\n").await; @@ -297,10 +317,10 @@ mod tests { assert_eq!( value, // it wont be a hashmap by default since there is no spec for hashmaps in RESP - Value::Multi(VecDeque::from([ + Some(Value::Multi(VecDeque::from([ Value::String("key".to_string()), Value::String("value".to_string()) - ])) + ]))) ); } }