diff --git a/Cargo.toml b/Cargo.toml index 9cc3aa06..38896a17 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mainline" -version = "1.1.0" +version = "1.2.0" authors = ["nuh.dev"] edition = "2018" description = "Simple, robust, BitTorrent's Mainline DHT implementation" @@ -22,6 +22,7 @@ flume = { version = "0.11.0", features = ["select", "eventual-fairness"], defaul ed25519-dalek = "2.1.0" bytes = "1.5.0" tracing = "0.1" +lru = { version = "0.12.2", default-features = false } [dev-dependencies] clap = { version = "4.4.8", features = ["derive"] } diff --git a/README.md b/README.md index c4aafffa..e1f92c33 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ let dht = Dht::client(); // or Dht::default(); Supported BEPs: - [x] [BEP0005 DHT Protocol](https://www.bittorrent.org/beps/bep_0005.html) -- [X] [BEP0042 DHT Security extension](https://www.bittorrent.org/beps/bep_0042.html) +- [x] [BEP0042 DHT Security extension](https://www.bittorrent.org/beps/bep_0042.html) - [x] [BEP0043 Read-only DHT Nodes](https://www.bittorrent.org/beps/bep_0043.html) - [x] [BEP0044 Storing arbitrary data in the DHT](https://www.bittorrent.org/beps/bep_0044.html) @@ -44,7 +44,7 @@ Supported BEPs: - [x] [BEP0005 DHT Protocol](https://www.bittorrent.org/beps/bep_0005.html) - [ ] [BEP0042 DHT Security extension](https://www.bittorrent.org/beps/bep_0042.html) - [x] [BEP0043 Read-only DHT Nodes](https://www.bittorrent.org/beps/bep_0043.html) -- [ ] [BEP0044 Storing arbitrary data in the DHT](https://www.bittorrent.org/beps/bep_0044.html) +- [x] [BEP0044 Storing arbitrary data in the DHT](https://www.bittorrent.org/beps/bep_0044.html) ## Acknowledgment diff --git a/src/async_dht.rs b/src/async_dht.rs index 6598e730..bea1b73d 100644 --- a/src/async_dht.rs +++ b/src/async_dht.rs @@ -2,12 +2,12 @@ use bytes::Bytes; -use crate::common::{ - hash_immutable, GetImmutableResponse, GetMutableResponse, GetPeerResponse, Id, MutableItem, - Node, Response, ResponseDone, ResponseMessage, StoreQueryMetdata, -}; +use crate::common::{hash_immutable, Id, MutableItem, Node, RoutingTable}; use crate::dht::ActorMessage; -use crate::routing_table::RoutingTable; +use crate::rpc::{ + GetImmutableResponse, GetMutableResponse, GetPeerResponse, Response, ResponseDone, + ResponseMessage, StoreQueryMetdata, +}; use crate::{Dht, Result}; use std::net::SocketAddr; @@ -202,3 +202,42 @@ impl Response { } } } + +#[cfg(test)] +mod test { + use std::time::Duration; + + use super::*; + use crate::Testnet; + + #[cfg(feature = "async")] + #[test] + fn announce_get_peer_async() { + async fn test() { + let testnet = Testnet::new(10); + + let a = Dht::builder() + .bootstrap(&testnet.bootstrap) + .build() + .as_async(); + let b = Dht::builder() + .bootstrap(&testnet.bootstrap) + .build() + .as_async(); + + let info_hash = Id::random(); + + match a.announce_peer(info_hash, Some(45555)).await { + Ok(_) => { + if let Some(r) = b.get_peers(info_hash).next_async().await { + assert_eq!(r.peer.port(), 45555); + } else { + panic!("No respnoses") + } + } + Err(_) => {} + }; + } + futures::executor::block_on(test()); + } +} diff --git a/src/common/id.rs b/src/common/id.rs index af3ec11e..ca5d5d7d 100644 --- a/src/common/id.rs +++ b/src/common/id.rs @@ -75,7 +75,7 @@ impl Id { let mut rng = rand::thread_rng(); let r: u8 = rng.gen(); - let mut bytes: [u8; 20] = rng.gen(); + let bytes: [u8; 20] = rng.gen(); match ip { IpAddr::V4(addr) => from_ipv4_and_r(bytes, addr, r), @@ -96,7 +96,7 @@ impl Id { expected == actual } - IpAddr::V6(ipv6) => { + IpAddr::V6(_ipv6) => { unimplemented!() // // For IPv6, checking the ULA range fc00::/7 diff --git a/src/common/mod.rs b/src/common/mod.rs index 7ff6b95e..b52dc16a 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -4,10 +4,10 @@ mod id; mod immutable; mod mutable; mod node; -mod response; +mod routing_table; pub use id::*; pub use immutable::*; pub use mutable::*; pub use node::*; -pub use response::*; +pub use routing_table::*; diff --git a/src/common/mutable.rs b/src/common/mutable.rs index 5969e5d5..727f3a79 100644 --- a/src/common/mutable.rs +++ b/src/common/mutable.rs @@ -21,6 +21,8 @@ pub struct MutableItem { signature: [u8; 64], /// Optional salt salt: Option, + /// Optional compare and swap seq + cas: Option, } impl MutableItem { @@ -38,6 +40,12 @@ impl MutableItem { ) } + /// Set the cas number if needed. + pub fn with_cas(mut self, cas: i64) -> Self { + self.cas = Some(cas); + self + } + /// Create a new mutable item from an already signed value. pub fn new_signed_unchecked( key: [u8; 32], @@ -53,6 +61,7 @@ impl MutableItem { seq, signature, salt, + cas: None, } } @@ -63,6 +72,7 @@ impl MutableItem { seq: &i64, signature: &[u8], salt: &Option, + cas: &Option, ) -> Result { let key = VerifyingKey::try_from(key).map_err(|_| Error::InvalidMutablePublicKey)?; @@ -79,6 +89,7 @@ impl MutableItem { seq: *seq, signature: signature.to_bytes(), salt: salt.to_owned(), + cas: *cas, }) } @@ -107,6 +118,10 @@ impl MutableItem { pub fn salt(&self) -> &Option { &self.salt } + + pub fn cas(&self) -> &Option { + &self.cas + } } pub fn target_from_key(public_key: &[u8; 32], salt: &Option) -> Id { diff --git a/src/routing_table.rs b/src/common/routing_table.rs similarity index 99% rename from src/routing_table.rs rename to src/common/routing_table.rs index 81721eeb..bb761a07 100644 --- a/src/routing_table.rs +++ b/src/common/routing_table.rs @@ -230,10 +230,7 @@ mod test { use std::net::SocketAddr; use std::str::FromStr; - use crate::{ - common::{Id, Node}, - routing_table::{KBucket, RoutingTable, MAX_BUCKET_SIZE_K}, - }; + use crate::common::{Id, KBucket, Node, RoutingTable, MAX_BUCKET_SIZE_K}; #[test] fn table_is_empty() { diff --git a/src/dht.rs b/src/dht.rs index 4bf572ee..86a25306 100644 --- a/src/dht.rs +++ b/src/dht.rs @@ -9,12 +9,11 @@ use bytes::Bytes; use flume::{Receiver, Sender}; use crate::{ - common::{ - hash_immutable, target_from_key, GetImmutableResponse, GetMutableResponse, GetPeerResponse, - Id, MutableItem, Node, Response, ResponseMessage, ResponseSender, StoreQueryMetdata, + common::{hash_immutable, target_from_key, Id, MutableItem, Node, RoutingTable}, + rpc::{ + GetImmutableResponse, GetMutableResponse, GetPeerResponse, Response, ResponseMessage, + ResponseSender, Rpc, StoreQueryMetdata, }, - routing_table::RoutingTable, - rpc::Rpc, Result, }; @@ -460,8 +459,11 @@ impl Testnet { #[cfg(test)] mod test { + use std::str::FromStr; use std::time::Duration; + use ed25519_dalek::SigningKey; + use super::*; #[test] @@ -478,6 +480,18 @@ mod test { dht.block_until_shutdown(); } + #[test] + fn bind_twice() { + let a = Dht::default(); + let b = Dht::builder() + .port(a.local_addr().unwrap().port()) + .as_server() + .build(); + + let result = b.handle.unwrap().join(); + assert!(result.is_err()); + } + #[test] fn announce_get_peer() { let testnet = Testnet::new(10); @@ -504,46 +518,75 @@ mod test { }; } - #[cfg(feature = "async")] #[test] - fn announce_get_peer_async() { - async fn test() { - let testnet = Testnet::new(10); - - let a = Dht::builder() - .bootstrap(&testnet.bootstrap) - .build() - .as_async(); - let b = Dht::builder() - .bootstrap(&testnet.bootstrap) - .build() - .as_async(); - - let info_hash = Id::random(); - - match a.announce_peer(info_hash, Some(45555)).await { - Ok(_) => { - if let Some(r) = b.get_peers(info_hash).next_async().await { - assert_eq!(r.peer.port(), 45555); - } else { + fn put_get_immutable() { + let testnet = Testnet::new(10); + + let a = Dht::builder().bootstrap(&testnet.bootstrap).build(); + let b = Dht::builder().bootstrap(&testnet.bootstrap).build(); + + let value: Bytes = "Hello World!".into(); + let expected_target = Id::from_str("e5f96f6f38320f0f33959cb4d3d656452117aadb").unwrap(); + + match a.put_immutable(value.clone()) { + Ok(result) => { + assert_ne!(result.stored_at().len(), 0); + assert_eq!(result.target(), expected_target); + + let responses: Vec<_> = b.get_immutable(result.target()).collect(); + + match responses.first() { + Some(r) => { + assert_eq!(r.value, value); + } + None => { panic!("No respnoses") } } - Err(_) => {} - }; - } - futures::executor::block_on(test()); + } + Err(_) => { + panic!("Expected put_immutable to succeeed") + } + }; } #[test] - fn bind_twice() { - let a = Dht::default(); - let b = Dht::builder() - .port(a.local_addr().unwrap().port()) - .as_server() - .build(); + fn put_get_mutable() { + let testnet = Testnet::new(10); - let result = b.handle.unwrap().join(); - assert!(result.is_err()); + let a = Dht::builder().bootstrap(&testnet.bootstrap).build(); + let b = Dht::builder().bootstrap(&testnet.bootstrap).build(); + + let signer = SigningKey::from_bytes(&[ + 56, 171, 62, 85, 105, 58, 155, 209, 189, 8, 59, 109, 137, 84, 84, 201, 221, 115, 7, + 228, 127, 70, 4, 204, 182, 64, 77, 98, 92, 215, 27, 103, + ]); + + let seq = 1000; + let value: Bytes = "Hello World!".into(); + + let item = MutableItem::new(signer.clone(), value, seq, None); + + match a.put_mutable(item.clone()) { + Ok(result) => { + assert_ne!(result.stored_at().len(), 0); + + let responses: Vec<_> = b + .get_mutable(signer.verifying_key().as_bytes(), None) + .collect(); + + match responses.first() { + Some(r) => { + assert_eq!(&r.item, &item); + } + None => { + panic!("No respnoses") + } + } + } + Err(_) => { + panic!("Expected put_immutable to succeeed") + } + }; } } diff --git a/src/lib.rs b/src/lib.rs index 6e5b3517..6e290a9c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,12 +11,7 @@ pub mod common; pub mod dht; pub mod error; mod messages; -mod peers; -mod query; -mod routing_table; mod rpc; -mod socket; -mod tokens; pub use crate::common::Id; pub use crate::error::Error; diff --git a/src/messages/internal.rs b/src/messages/internal.rs index 2a53a75b..61ebe2dd 100644 --- a/src/messages/internal.rs +++ b/src/messages/internal.rs @@ -78,13 +78,13 @@ pub enum DHTRequestSpecific { #[serde(rename = "get")] GetValue { #[serde(rename = "a")] - arguments: DHTGetValueArguments, + arguments: DHTGetValueRequestArguments, }, #[serde(rename = "put")] PutValue { #[serde(rename = "a")] - arguments: DHTPutValueArguments, + arguments: DHTPutValueRequestArguments, }, } @@ -232,12 +232,15 @@ pub struct DHTAnnouncePeerRequestArguments { // === Get Value === #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct DHTGetValueArguments { +pub struct DHTGetValueRequestArguments { #[serde(with = "serde_bytes")] pub id: Vec, #[serde(with = "serde_bytes")] pub target: Vec, + + #[serde(default)] + pub seq: Option, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] @@ -283,7 +286,7 @@ pub struct DHTGetMutableResponseArguments { // === Put Value === #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct DHTPutValueArguments { +pub struct DHTPutValueRequestArguments { #[serde(with = "serde_bytes")] pub id: Vec, @@ -297,13 +300,20 @@ pub struct DHTPutValueArguments { pub v: Vec, #[serde(with = "serde_bytes")] + #[serde(default)] pub k: Option>, #[serde(with = "serde_bytes")] + #[serde(default)] pub sig: Option>, + #[serde(default)] pub seq: Option, + #[serde(default)] + pub cas: Option, + #[serde(with = "serde_bytes")] + #[serde(default)] pub salt: Option>, } diff --git a/src/messages/mod.rs b/src/messages/mod.rs index 6ade0b51..3b07e81d 100644 --- a/src/messages/mod.rs +++ b/src/messages/mod.rs @@ -51,10 +51,9 @@ pub enum RequestSpecific { FindNode(FindNodeRequestArguments), GetPeers(GetPeersRequestArguments), AnnouncePeer(AnnouncePeerRequestArguments), - GetValue(GetValueRequestArguments), PutImmutable(PutImmutableRequestArguments), - GetMutable(GetMutableRequestArguments), PutMutable(PutMutableRequestArguments), + GetValue(GetValueRequestArguments), } #[derive(Debug, PartialEq, Clone)] @@ -97,6 +96,11 @@ pub struct FindNodeResponseArguments { pub struct GetValueRequestArguments { pub requester_id: Id, pub target: Id, + pub seq: Option, + // A bit of a hack, using this to carry an optional + // salt in the query.request field of [crate::query] + // not really encoded, decoded or sent over the wire. + pub salt: Option, } #[derive(Debug, PartialEq, Clone)] @@ -145,16 +149,6 @@ pub struct GetImmutableResponseArguments { // === Get Mutable === -#[derive(Debug, PartialEq, Clone)] -pub struct GetMutableRequestArguments { - pub requester_id: Id, - pub target: Id, - // A bit of a hack, using this to carry an optional - // salt in the query.request field of [crate::query] - // not really encoded, decoded or sent over the wire. - pub salt: Option, -} - #[derive(Debug, PartialEq, Clone)] pub struct GetMutableResponseArguments { pub responder_id: Id, @@ -188,6 +182,7 @@ pub struct PutMutableRequestArguments { pub seq: i64, pub sig: Vec, pub salt: Option>, + pub cas: Option, } impl Message { @@ -237,17 +232,9 @@ impl Message { }, } } - RequestSpecific::GetValue(get_value_arguments) => { - internal::DHTRequestSpecific::GetValue { - arguments: internal::DHTGetValueArguments { - id: get_value_arguments.requester_id.to_vec(), - target: get_value_arguments.target.to_vec(), - }, - } - } RequestSpecific::PutImmutable(put_immutable_arguments) => { internal::DHTRequestSpecific::PutValue { - arguments: internal::DHTPutValueArguments { + arguments: internal::DHTPutValueRequestArguments { id: put_immutable_arguments.requester_id.to_vec(), target: put_immutable_arguments.target.to_vec(), token: put_immutable_arguments.token, @@ -256,28 +243,31 @@ impl Message { seq: None, sig: None, salt: None, + cas: None, }, } } - RequestSpecific::GetMutable(get_mutable_args) => { + RequestSpecific::GetValue(get_mutable_args) => { internal::DHTRequestSpecific::GetValue { - arguments: internal::DHTGetValueArguments { + arguments: internal::DHTGetValueRequestArguments { id: get_mutable_args.requester_id.to_vec(), target: get_mutable_args.target.to_vec(), + seq: get_mutable_args.seq, }, } } RequestSpecific::PutMutable(put_mutable_arguments) => { internal::DHTRequestSpecific::PutValue { - arguments: internal::DHTPutValueArguments { + arguments: internal::DHTPutValueRequestArguments { id: put_mutable_arguments.requester_id.to_vec(), target: put_mutable_arguments.target.to_vec(), token: put_mutable_arguments.token, v: put_mutable_arguments.v, - k: Some(put_mutable_arguments.k), + k: Some(put_mutable_arguments.k.to_vec()), seq: Some(put_mutable_arguments.seq), sig: Some(put_mutable_arguments.sig), salt: put_mutable_arguments.salt, + cas: put_mutable_arguments.cas, }, } } @@ -417,6 +407,8 @@ impl Message { RequestSpecific::GetValue(GetValueRequestArguments { requester_id: Id::from_bytes(arguments.id)?, target: Id::from_bytes(arguments.target)?, + seq: arguments.seq, + salt: None, }) } internal::DHTRequestSpecific::PutValue { arguments } => { @@ -431,6 +423,7 @@ impl Message { seq: arguments.seq.unwrap(), sig: arguments.sig.unwrap(), salt: arguments.salt, + cas: arguments.cas, }) } else { RequestSpecific::PutImmutable(PutImmutableRequestArguments { @@ -563,10 +556,9 @@ impl Message { RequestSpecific::FindNode(arguments) => arguments.requester_id, RequestSpecific::GetPeers(arguments) => arguments.requester_id, RequestSpecific::AnnouncePeer(arguments) => arguments.requester_id, - RequestSpecific::GetValue(arguments) => arguments.requester_id, RequestSpecific::PutImmutable(arguments) => arguments.requester_id, - RequestSpecific::GetMutable(arguments) => arguments.requester_id, RequestSpecific::PutMutable(arguments) => arguments.requester_id, + RequestSpecific::GetValue(arguments) => arguments.requester_id, }, MessageType::Response(response_variant) => match response_variant { ResponseSpecific::Ping(arguments) => arguments.responder_id, @@ -970,6 +962,8 @@ mod tests { GetValueRequestArguments { requester_id: Id::random(), target: Id::random(), + seq: Some(1231), + salt: None, }, )), }; @@ -1004,4 +998,57 @@ mod tests { let parsed_msg = Message::from_serde_message(parsed_serde_msg).unwrap(); assert_eq!(parsed_msg, original_msg); } + + #[test] + fn test_put_immutable_request() { + let original_msg = Message { + transaction_id: 3, + version: Some(vec![1]), + requester_ip: Some("50.51.52.53:5455".parse().unwrap()), + read_only: false, + message_type: MessageType::Request(RequestSpecific::PutImmutable( + PutImmutableRequestArguments { + requester_id: Id::random(), + target: Id::random(), + token: vec![99, 100, 101, 102], + v: vec![99, 100, 101, 102], + }, + )), + }; + + let serde_msg = original_msg.clone().into_serde_message(); + let bytes = serde_msg.to_bytes().unwrap(); + let parsed_serde_msg = internal::DHTMessage::from_bytes(bytes).unwrap(); + let parsed_msg = Message::from_serde_message(parsed_serde_msg).unwrap(); + assert_eq!(parsed_msg, original_msg); + } + + #[test] + fn test_put_mutable_request() { + let original_msg = Message { + transaction_id: 3, + version: Some(vec![1]), + requester_ip: Some("50.51.52.53:5455".parse().unwrap()), + read_only: false, + message_type: MessageType::Request(RequestSpecific::PutMutable( + PutMutableRequestArguments { + requester_id: Id::random(), + target: Id::random(), + token: vec![99, 100, 101, 102], + v: vec![99, 100, 101, 102], + k: vec![100, 101, 102, 103], + seq: 100, + sig: vec![0, 1, 2, 3], + salt: Some(vec![0, 2, 4, 8]), + cas: Some(100), + }, + )), + }; + + let serde_msg = original_msg.clone().into_serde_message(); + let bytes = serde_msg.to_bytes().unwrap(); + let parsed_serde_msg = internal::DHTMessage::from_bytes(bytes).unwrap(); + let parsed_msg = Message::from_serde_message(parsed_serde_msg).unwrap(); + assert_eq!(parsed_msg, original_msg); + } } diff --git a/src/peers.rs b/src/peers.rs deleted file mode 100644 index 8f5d2573..00000000 --- a/src/peers.rs +++ /dev/null @@ -1,113 +0,0 @@ -//! Manage announced peers for info_hashes - -use std::{collections::HashMap, net::SocketAddr}; - -use rand::{rngs::ThreadRng, seq::SliceRandom, thread_rng}; - -use crate::{routing_table::MAX_BUCKET_SIZE_K, Id}; - -const MAX_PEERS: usize = 10000; -const MAX_PEERS_PER_INFO_HASH: usize = 100; - -#[derive(Debug)] -pub struct PeersStore { - rng: ThreadRng, - lru: Vec<(Id, SocketAddr)>, - peers: HashMap>, -} - -impl PeersStore { - pub fn new() -> Self { - Self { - rng: thread_rng(), - lru: Vec::new(), - peers: HashMap::new(), - } - } - - pub fn add_peer(&mut self, info_hash: Id, peer: SocketAddr) { - let incoming = (info_hash, peer); - - // If the item is already in the LRU, bring it to the end. - // This is where we ensure unique (info_hash, peer) tuple in the store. - if let Some(index) = self.lru.iter().position(|i| i == &incoming) { - // Bring the item to the end of the LRU - self.lru.remove(index); - self.lru.push(incoming); - return; - } - - // If the LRU is full, remove the oldest item. - if self.lru.len() >= MAX_PEERS { - let (info_hash, _) = self.lru.remove(0); - let peers = self.peers.get_mut(&info_hash).unwrap(); - peers.remove(0); - } - - let info_hash_peers = self.peers.entry(info_hash).or_default(); - - // If the info_hash is full of peers, remove the oldest one. - if info_hash_peers.len() >= MAX_PEERS_PER_INFO_HASH { - let peer = info_hash_peers.remove(0); - - self.lru.retain(|i| i != &(info_hash, peer)) - } - - // Add the new item to the info_hash_peers - info_hash_peers.push(peer); - self.lru.push(incoming); - } - - pub fn get_random_peers(&mut self, info_hash: &Id) -> Option> { - if let Some(peers) = self.peers.get(info_hash) { - let peers = peers.clone(); - - let random_20: Vec = peers - .choose_multiple(&mut self.rng, MAX_BUCKET_SIZE_K) - .cloned() - .collect(); - - return Some(random_20); - } - - None - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn lru() { - let mut store = PeersStore::new(); - - let info_hash_a = Id::random(); - let info_hash_b = Id::random(); - - store.add_peer(info_hash_a, SocketAddr::from(([127, 0, 1, 1], 0))); - store.add_peer(info_hash_a, SocketAddr::from(([127, 0, 1, 1], 0))); - store.add_peer(info_hash_a, SocketAddr::from(([127, 0, 1, 1], 0))); - - store.add_peer(info_hash_b, SocketAddr::from(([127, 0, 2, 1], 0))); - - store.add_peer(info_hash_a, SocketAddr::from(([127, 0, 1, 2], 0))); - store.add_peer(info_hash_b, SocketAddr::from(([127, 0, 2, 2], 0))); - store.add_peer(info_hash_a, SocketAddr::from(([127, 0, 1, 2], 0))); - - let mut order: Vec<(Id, SocketAddr)> = Vec::new(); - - for item in &store.lru { - order.push(*item); - } - - let expected = vec![ - (info_hash_a, SocketAddr::from(([127, 0, 1, 1], 0))), - (info_hash_b, SocketAddr::from(([127, 0, 2, 1], 0))), - (info_hash_b, SocketAddr::from(([127, 0, 2, 2], 0))), - (info_hash_a, SocketAddr::from(([127, 0, 1, 2], 0))), - ]; - - assert_eq!(order, expected); - } -} diff --git a/src/rpc.rs b/src/rpc/mod.rs similarity index 77% rename from src/rpc.rs rename to src/rpc/mod.rs index 35b625ca..fd1c1a5c 100644 --- a/src/rpc.rs +++ b/src/rpc/mod.rs @@ -1,28 +1,36 @@ +//! K-RPC implementation + +mod query; +pub mod response; +mod server; +mod socket; + use std::collections::HashMap; use std::net::{SocketAddr, ToSocketAddrs}; +use std::num::NonZeroUsize; use std::time::{Duration, Instant}; use bytes::Bytes; +use lru::LruCache; use tracing::{debug, error}; -use crate::common::{ - validate_immutable, GetImmutableResponse, GetMutableResponse, GetPeerResponse, Id, MutableItem, - Node, ResponseSender, ResponseValue, -}; +use crate::common::{validate_immutable, Id, MutableItem, Node, RoutingTable}; use crate::messages::{ - AnnouncePeerRequestArguments, FindNodeRequestArguments, FindNodeResponseArguments, - GetImmutableResponseArguments, GetMutableRequestArguments, GetMutableResponseArguments, - GetPeersRequestArguments, GetPeersResponseArguments, GetValueRequestArguments, Message, - MessageType, NoValuesResponseArguments, PingRequestArguments, PingResponseArguments, + AnnouncePeerRequestArguments, FindNodeRequestArguments, GetImmutableResponseArguments, + GetMutableResponseArguments, GetPeersRequestArguments, GetPeersResponseArguments, + GetValueRequestArguments, Message, MessageType, PingRequestArguments, PingResponseArguments, PutImmutableRequestArguments, PutMutableRequestArguments, RequestSpecific, ResponseSpecific, }; -use crate::peers::PeersStore; -use crate::query::{Query, StoreQuery}; -use crate::routing_table::RoutingTable; -use crate::socket::KrpcSocket; -use crate::tokens::Tokens; +pub use response::{ + GetImmutableResponse, GetMutableResponse, GetPeerResponse, Response, ResponseDone, + ResponseMessage, ResponseSender, ResponseValue, StoreQueryMetdata, +}; + use crate::Result; +use query::{Query, StoreQuery}; +use server::{handle_request, PeersStore, Tokens}; +use socket::KrpcSocket; const DEFAULT_BOOTSTRAP_NODES: [&str; 4] = [ "router.bittorrent.com:6881", @@ -34,6 +42,11 @@ const DEFAULT_BOOTSTRAP_NODES: [&str; 4] = [ const REFRESH_TABLE_INTERVAL: Duration = Duration::from_secs(15 * 60); const PING_TABLE_INTERVAL: Duration = Duration::from_secs(5 * 60); +// Stored data in server mode. +const MAX_INFO_HASHES: usize = 2000; +const MAX_PEERS: usize = 500; +const MAX_VALUES: usize = 1000; + #[derive(Debug)] pub struct Rpc { socket: KrpcSocket, @@ -41,7 +54,10 @@ pub struct Rpc { queries: HashMap, store_queries: HashMap, tokens: Tokens, + peers: PeersStore, + immutable_values: LruCache, + mutable_values: LruCache, /// Last time we refreshed the routing table with a find_node query. last_table_refresh: Instant, @@ -55,7 +71,7 @@ pub struct Rpc { impl Rpc { pub fn new() -> Result { - // TODO: One day I might implement BEP42. + // TODO: One day I might implement BEP42 on Routing nodes. let id = Id::random(); let socket = KrpcSocket::new()?; @@ -71,7 +87,13 @@ impl Rpc { queries: HashMap::new(), store_queries: HashMap::new(), tokens: Tokens::new(), - peers: PeersStore::new(), + + peers: PeersStore::new( + NonZeroUsize::new(MAX_INFO_HASHES).unwrap(), + NonZeroUsize::new(MAX_PEERS).unwrap(), + ), + immutable_values: LruCache::new(NonZeroUsize::new(MAX_VALUES).unwrap()), + mutable_values: LruCache::new(NonZeroUsize::new(MAX_VALUES).unwrap()), last_table_refresh: Instant::now() - REFRESH_TABLE_INTERVAL, last_table_ping: Instant::now(), @@ -122,6 +144,11 @@ impl Rpc { // === Public Methods === pub fn tick(&mut self) { + // === Tokens === + if self.tokens.should_update() { + self.tokens.rotate() + } + // === Tick Queries === for (_, query) in self.queries.iter_mut() { query.tick(&mut self.socket); @@ -161,13 +188,13 @@ impl Rpc { match &message.message_type { MessageType::Request(request_specific) => { - self.handle_request(from, message.transaction_id, request_specific); + handle_request(self, from, message.transaction_id, request_specific) } MessageType::Response(_) => { self.handle_response(from, &message); } MessageType::Error(error) => { - debug!(?message, "RPC Error response"); + debug!(?error, "RPC Error response"); } } }; @@ -225,6 +252,8 @@ impl Rpc { RequestSpecific::GetValue(GetValueRequestArguments { requester_id: self.id, target, + seq: None, + salt: None, }), Some(sender), ) @@ -260,9 +289,10 @@ impl Rpc { pub fn get_mutable(&mut self, target: Id, salt: Option, sender: ResponseSender) { self.query( target, - RequestSpecific::GetMutable(GetMutableRequestArguments { + RequestSpecific::GetValue(GetValueRequestArguments { requester_id: self.id, target, + seq: None, salt, }), Some(sender), @@ -285,6 +315,7 @@ impl Rpc { seq: *item.seq(), sig: item.signature().to_vec(), salt: item.salt().clone().map(|s| s.to_vec()), + cas: *item.cas(), }), &mut self.socket, ); @@ -342,86 +373,6 @@ impl Rpc { self.queries.insert(target, query); } - fn handle_request(&mut self, from: SocketAddr, transaction_id: u16, request: &RequestSpecific) { - match request { - // TODO: Handle bad requests (send an error message). - RequestSpecific::Ping(_) => { - self.socket.response( - from, - transaction_id, - ResponseSpecific::Ping(PingResponseArguments { - responder_id: self.id, - }), - ); - } - RequestSpecific::FindNode(FindNodeRequestArguments { target, .. }) => { - self.socket.response( - from, - transaction_id, - ResponseSpecific::FindNode(FindNodeResponseArguments { - responder_id: self.id, - nodes: self.routing_table.closest(target), - }), - ); - } - RequestSpecific::GetPeers(GetPeersRequestArguments { info_hash, .. }) => { - self.socket.response( - from, - transaction_id, - match self.peers.get_random_peers(info_hash) { - Some(peers) => ResponseSpecific::GetPeers(GetPeersResponseArguments { - responder_id: self.id, - token: self.tokens.generate_token(from).into(), - nodes: Some(self.routing_table.closest(info_hash)), - values: peers, - }), - None => ResponseSpecific::NoValues(NoValuesResponseArguments { - responder_id: self.id, - token: self.tokens.generate_token(from).into(), - nodes: Some(self.routing_table.closest(info_hash)), - }), - }, - ); - } - RequestSpecific::AnnouncePeer(AnnouncePeerRequestArguments { - info_hash, - port, - implied_port, - token, - .. - }) => { - if self.tokens.validate(from, token) { - let peer = match implied_port { - Some(true) => from, - _ => SocketAddr::new(from.ip(), *port), - }; - - self.peers.add_peer(*info_hash, peer); - - self.socket.response( - from, - transaction_id, - ResponseSpecific::Ping(PingResponseArguments { - responder_id: self.id, - }), - ); - } else { - // TODO: Send an error message. - } - } - RequestSpecific::PutImmutable(PutImmutableRequestArguments { v, target, .. }) => { - if v.len() > 1000 || !validate_immutable(v, target) { - // TODO: return and log error. - } - // TODO: store immutable items. - } - _ => { - // TODO: How to deal with unknown requests? - // Maybe just return CloserNodesAndToken to the sender? - } - } - } - fn handle_response(&mut self, from: SocketAddr, message: &Message) { if message.read_only { return; @@ -438,6 +389,7 @@ impl Rpc { responder_id, })) = message.message_type { + // Mark storage at that node as a success. query.success(responder_id); } @@ -501,9 +453,7 @@ impl Rpc { }, )) => { let salt = match query.request() { - RequestSpecific::GetMutable(GetMutableRequestArguments { - salt, .. - }) => salt, + RequestSpecific::GetValue(GetValueRequestArguments { salt, .. }) => salt, _ => &None, }; let target = query.target(); @@ -515,6 +465,7 @@ impl Rpc { seq, sig, salt, + &None, ) { query.response(ResponseValue::Mutable(GetMutableResponse { from: Node::new(*responder_id, from), diff --git a/src/query.rs b/src/rpc/query.rs similarity index 97% rename from src/query.rs rename to src/rpc/query.rs index 97821156..53d02f1e 100644 --- a/src/query.rs +++ b/src/rpc/query.rs @@ -3,12 +3,12 @@ use std::collections::HashSet; use std::net::SocketAddr; -use crate::common::{ - Id, Node, ResponseDone, ResponseMessage, ResponseSender, ResponseValue, StoreQueryMetdata, +use super::response::{ + ResponseDone, ResponseMessage, ResponseSender, ResponseValue, StoreQueryMetdata, }; +use super::socket::KrpcSocket; +use crate::common::{Id, Node, RoutingTable}; use crate::messages::RequestSpecific; -use crate::routing_table::RoutingTable; -use crate::socket::KrpcSocket; /// A query is an iterative process of concurrently sending a request to the closest known nodes to /// the target, updating the routing table with closer nodes discovered in the responses, and @@ -108,7 +108,7 @@ impl Query { self.with_token.add(node.clone()); } - /// Add reveived response + /// Add received response pub fn response(&mut self, response: ResponseValue) { self.responses.push(response.clone()); diff --git a/src/common/response.rs b/src/rpc/response.rs similarity index 100% rename from src/common/response.rs rename to src/rpc/response.rs diff --git a/src/rpc/server/mod.rs b/src/rpc/server/mod.rs new file mode 100644 index 00000000..f807d5a4 --- /dev/null +++ b/src/rpc/server/mod.rs @@ -0,0 +1,9 @@ +//! Modules needed only for nodes running in server mode (not read-only). + +pub mod peers; +pub mod request; +pub mod tokens; + +pub use peers::*; +pub use request::*; +pub use tokens::*; diff --git a/src/rpc/server/peers.rs b/src/rpc/server/peers.rs new file mode 100644 index 00000000..750c0705 --- /dev/null +++ b/src/rpc/server/peers.rs @@ -0,0 +1,160 @@ +//! Manage announced peers for info_hashes + +use std::{net::SocketAddr, num::NonZeroUsize}; + +use rand::{rngs::ThreadRng, thread_rng, Rng}; + +use crate::common::Id; + +use lru::LruCache; + +#[derive(Debug)] +pub struct PeersStore { + rng: ThreadRng, + info_hashes: LruCache>, + max_peers: NonZeroUsize, +} + +impl PeersStore { + pub fn new(max_info_hashes: NonZeroUsize, max_peers: NonZeroUsize) -> Self { + Self { + rng: thread_rng(), + info_hashes: LruCache::new(max_info_hashes), + max_peers, + } + } + + pub fn add_peer(&mut self, info_hash: Id, peer: (&Id, SocketAddr)) { + if let Some(info_hash_lru) = self.info_hashes.get_mut(&info_hash) { + info_hash_lru.put(*peer.0, peer.1); + } else { + let mut info_hash_lru = LruCache::new(self.max_peers); + info_hash_lru.put(*peer.0, peer.1); + self.info_hashes.put(info_hash, info_hash_lru); + }; + } + + pub fn get_random_peers(&mut self, info_hash: &Id) -> Option> { + if let Some(info_hash_lru) = self.info_hashes.get(info_hash) { + let size = info_hash_lru.len(); + let target_size = 20; + + if size == 0 { + return None; + } else if size < target_size { + return Some( + info_hash_lru + .iter() + .map(|n| n.1.to_owned()) + .collect::>(), + ); + } + + let mut results = Vec::with_capacity(20); + + for (index, (_, addr)) in info_hash_lru.iter().enumerate() { + // Calculate the chance of adding the current item based on remaining items and slots + let remaining_slots = target_size - results.len(); + let remaining_items = info_hash_lru.len() - index; + let current_chance = remaining_slots as f64 / remaining_items as f64; + + // Randomly decide to add the item based on the current chance + if self.rng.gen_bool(current_chance) { + results.push(addr.to_owned()); + if results.len() == target_size { + break; + } + } + } + + return Some(results); + } + + None + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn max_info_hashes() { + let mut store = PeersStore::new( + NonZeroUsize::new(1).unwrap(), + NonZeroUsize::new(100).unwrap(), + ); + + let info_hash_a = Id::random(); + let info_hash_b = Id::random(); + + store.add_peer( + info_hash_a, + (&info_hash_a, SocketAddr::from(([127, 0, 1, 1], 0))), + ); + store.add_peer( + info_hash_b, + (&info_hash_b, SocketAddr::from(([127, 0, 1, 1], 0))), + ); + + assert_eq!(store.info_hashes.len(), 1); + assert_eq!( + store.get_random_peers(&info_hash_b), + Some(vec![SocketAddr::from(([127, 0, 1, 1], 0))]) + ); + } + + #[test] + fn all_peers() { + let mut store = + PeersStore::new(NonZeroUsize::new(1).unwrap(), NonZeroUsize::new(2).unwrap()); + + let info_hash_a = Id::random(); + let info_hash_b = Id::random(); + let info_hash_c = Id::random(); + + store.add_peer( + info_hash_a, + (&info_hash_a, SocketAddr::from(([127, 0, 1, 1], 0))), + ); + store.add_peer( + info_hash_a, + (&info_hash_b, SocketAddr::from(([127, 0, 1, 2], 0))), + ); + store.add_peer( + info_hash_a, + (&info_hash_c, SocketAddr::from(([127, 0, 1, 3], 0))), + ); + + assert_eq!( + store.get_random_peers(&info_hash_a), + Some(vec![ + SocketAddr::from(([127, 0, 1, 3], 0)), + SocketAddr::from(([127, 0, 1, 2], 0)), + ]) + ); + } + + #[test] + fn random_peers_subset() { + let mut store = PeersStore::new( + NonZeroUsize::new(1).unwrap(), + NonZeroUsize::new(200).unwrap(), + ); + + let info_hash = Id::random(); + + for i in (0..200) { + store.add_peer( + info_hash, + (&Id::random(), SocketAddr::from(([127, 0, 1, i], 0))), + ) + } + + assert_eq!(store.info_hashes.get(&info_hash).unwrap().len(), 200); + + let sample = store.get_random_peers(&info_hash).unwrap(); + + assert_eq!(sample.len(), 20); + } +} diff --git a/src/rpc/server/request.rs b/src/rpc/server/request.rs new file mode 100644 index 00000000..a10a29ab --- /dev/null +++ b/src/rpc/server/request.rs @@ -0,0 +1,291 @@ +//! Request hanlders + +use std::net::SocketAddr; + +use tracing::debug; + +use crate::common::{validate_immutable, Id, MutableItem}; +use crate::messages::{ + AnnouncePeerRequestArguments, ErrorSpecific, FindNodeRequestArguments, + FindNodeResponseArguments, GetImmutableResponseArguments, GetMutableResponseArguments, + GetPeersRequestArguments, GetPeersResponseArguments, GetValueRequestArguments, + NoValuesResponseArguments, PingResponseArguments, PutImmutableRequestArguments, + PutMutableRequestArguments, RequestSpecific, ResponseSpecific, +}; + +use super::super::Rpc; + +pub fn handle_request( + rpc: &mut Rpc, + from: SocketAddr, + transaction_id: u16, + request: &RequestSpecific, +) { + match request { + RequestSpecific::Ping(_) => { + rpc.socket.response( + from, + transaction_id, + ResponseSpecific::Ping(PingResponseArguments { + responder_id: rpc.id, + }), + ); + } + RequestSpecific::FindNode(FindNodeRequestArguments { target, .. }) => { + rpc.socket.response( + from, + transaction_id, + ResponseSpecific::FindNode(FindNodeResponseArguments { + responder_id: rpc.id, + nodes: rpc.routing_table.closest(target), + }), + ); + } + RequestSpecific::GetPeers(GetPeersRequestArguments { info_hash, .. }) => { + rpc.socket.response( + from, + transaction_id, + match rpc.peers.get_random_peers(info_hash) { + Some(peers) => ResponseSpecific::GetPeers(GetPeersResponseArguments { + responder_id: rpc.id, + token: rpc.tokens.generate_token(from).into(), + nodes: Some(rpc.routing_table.closest(info_hash)), + values: peers, + }), + None => ResponseSpecific::NoValues(NoValuesResponseArguments { + responder_id: rpc.id, + token: rpc.tokens.generate_token(from).into(), + nodes: Some(rpc.routing_table.closest(info_hash)), + }), + }, + ); + } + RequestSpecific::AnnouncePeer(AnnouncePeerRequestArguments { + info_hash, + port, + implied_port, + token, + requester_id, + .. + }) => { + if rpc.tokens.validate(from, token) { + let peer = match implied_port { + Some(true) => from, + _ => SocketAddr::new(from.ip(), *port), + }; + + rpc.peers.add_peer(*info_hash, (requester_id, peer)); + + rpc.socket.response( + from, + transaction_id, + ResponseSpecific::Ping(PingResponseArguments { + responder_id: rpc.id, + }), + ); + } else { + rpc.socket.error( + from, + transaction_id, + ErrorSpecific { + code: 203, + description: "Bad token".to_string(), + }, + ); + debug!(?from, ?token, "Invalid token"); + } + } + RequestSpecific::PutImmutable(PutImmutableRequestArguments { v, target, .. }) => { + if v.len() > 1000 { + rpc.socket.error( + from, + transaction_id, + ErrorSpecific { + code: 205, + description: "Message (v field) too big.".to_string(), + }, + ); + return; + } + if !validate_immutable(v, target) { + rpc.socket.error( + from, + transaction_id, + ErrorSpecific { + code: 203, + description: "Target doesn't match the sha1 hash of v field".to_string(), + }, + ); + return; + } + + rpc.immutable_values.put(*target, v.to_owned().into()); + + rpc.socket.response( + from, + transaction_id, + ResponseSpecific::Ping(PingResponseArguments { + responder_id: rpc.id, + }), + ); + } + RequestSpecific::PutMutable(PutMutableRequestArguments { + target, + v, + k, + seq, + sig, + salt, + cas, + .. + }) => { + if v.len() > 1000 { + rpc.socket.error( + from, + transaction_id, + ErrorSpecific { + code: 205, + description: "Message (v field) too big.".to_string(), + }, + ); + return; + } + if let Some(salt) = salt { + if salt.len() > 64 { + rpc.socket.error( + from, + transaction_id, + ErrorSpecific { + code: 207, + description: "salt (salt field) too big.".to_string(), + }, + ); + return; + } + } + if let Some(previous) = rpc.mutable_values.get(target) { + if let Some(cas) = cas { + if previous.seq() != cas { + rpc.socket.error( + from, + transaction_id, + ErrorSpecific { + code: 301, + description: "CAS mismatched, re-read value and try again." + .to_string(), + }, + ); + + return; + } + }; + + if seq <= previous.seq() { + rpc.socket.error( + from, + transaction_id, + ErrorSpecific { + code: 302, + description: "Sequence number less than current.".to_string(), + }, + ); + + return; + } + } + + match MutableItem::from_dht_message( + target, + k, + v.to_owned().into(), + seq, + sig, + &salt.to_owned().map(|v| v.into()), + cas, + ) { + Ok(item) => { + rpc.mutable_values.put(*target, item); + + rpc.socket.response( + from, + transaction_id, + ResponseSpecific::Ping(PingResponseArguments { + responder_id: rpc.id, + }), + ); + } + Err(_) => { + rpc.socket.error( + from, + transaction_id, + ErrorSpecific { + code: 206, + description: "Invalid signature".to_string(), + }, + ); + } + } + } + RequestSpecific::GetValue(GetValueRequestArguments { + requester_id, + target, + seq, + .. + }) => { + if seq.is_some() { + return handle_get_mutable(rpc, from, transaction_id, requester_id, target, seq); + } + + if let Some(v) = rpc.immutable_values.get(target) { + rpc.socket.response( + from, + transaction_id, + ResponseSpecific::GetImmutable(GetImmutableResponseArguments { + responder_id: rpc.id, + token: rpc.tokens.generate_token(from).into(), + nodes: Some(rpc.routing_table.closest(target)), + v: v.to_vec(), + }), + ) + } else { + handle_get_mutable(rpc, from, transaction_id, requester_id, target, seq); + }; + } + } +} + +fn handle_get_mutable( + rpc: &mut Rpc, + from: SocketAddr, + transaction_id: u16, + _requester_id: &Id, + target: &Id, + _seq: &Option, +) { + rpc.socket.response( + from, + transaction_id, + match rpc.mutable_values.get(target) { + Some(item) => { + // TODO: support seq (NoMoreRecentValue) + // if let Some(seq) = seq { + // } + + ResponseSpecific::GetMutable(GetMutableResponseArguments { + responder_id: rpc.id, + token: rpc.tokens.generate_token(from).into(), + nodes: Some(rpc.routing_table.closest(target)), + v: item.value().to_vec(), + k: item.key().to_vec(), + seq: *item.seq(), + sig: item.signature().to_vec(), + }) + } + None => ResponseSpecific::NoValues(NoValuesResponseArguments { + responder_id: rpc.id, + token: rpc.tokens.generate_token(from).into(), + nodes: Some(rpc.routing_table.closest(target)), + }), + }, + ) +} diff --git a/src/tokens.rs b/src/rpc/server/tokens.rs similarity index 97% rename from src/tokens.rs rename to src/rpc/server/tokens.rs index cfd4d943..0926a4e7 100644 --- a/src/tokens.rs +++ b/src/rpc/server/tokens.rs @@ -8,6 +8,8 @@ use std::{ time::{Duration, Instant}, }; +use tracing::trace; + const SECRET_SIZE: usize = 20; const TOKEN_SIZE: usize = 4; const ROTATE_INTERVAL: Duration = Duration::from_secs(60 * 5); @@ -56,6 +58,8 @@ impl Tokens { } pub fn rotate(&mut self) { + trace!("Rotating secrets"); + self.prev_secret = self.curr_secret; self.curr_secret = self.rng.gen(); diff --git a/src/socket.rs b/src/rpc/socket.rs similarity index 96% rename from src/socket.rs rename to src/rpc/socket.rs index 8a833e73..7866cf65 100644 --- a/src/socket.rs +++ b/src/rpc/socket.rs @@ -91,7 +91,9 @@ impl KrpcSocket { ); let tid = message.transaction_id; - let _ = self.send(address, message); + let _ = self.send(address, message).map_err(|e| { + debug!(?e, "Error sending request message"); + }); tid } @@ -105,18 +107,17 @@ impl KrpcSocket { ) { let message = self.response_message(MessageType::Response(response), address, transaction_id); - let _ = self.send(address, message); + let _ = self.send(address, message).map_err(|e| { + debug!(?e, "Error sending response message"); + }); } /// Send an error to the given address. - pub fn error( - &mut self, - address: SocketAddr, - transaction_id: u16, - error: ErrorSpecific, - ) -> Result<()> { + pub fn error(&mut self, address: SocketAddr, transaction_id: u16, error: ErrorSpecific) { let message = self.response_message(MessageType::Error(error), address, transaction_id); - self.send(address, message) + let _ = self.send(address, message).map_err(|e| { + debug!(?e, "Error sending error"); + }); } /// Receives a single krpc message on the socket.