diff --git a/Cargo.toml b/Cargo.toml index 7d1fca1..d0de396 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ serde_json = "1.0" serde_yaml = "0.8" tempfile = "3.1" -tokio = { version = "1.40", features = ["net", "rt", "signal"] } +tokio = { version = "1.40", features = ["net", "rt", "signal", "sync"] } [target.'cfg(target_os="linux")'.dependencies] v4l = { version = "0.14", features = ["v4l2"] } diff --git a/src/bin/rrhttp.rs b/src/bin/rrhttp.rs index 7980417..f79d402 100644 --- a/src/bin/rrhttp.rs +++ b/src/bin/rrhttp.rs @@ -1,6 +1,7 @@ // Copyright (C) 2024 rerobots, Inc. use std::fmt::Write; +use std::sync::Arc; #[macro_use] extern crate clap; @@ -15,9 +16,9 @@ use serde::Deserialize; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::runtime::Builder; -use tokio::{signal, time}; +use tokio::{signal, sync::mpsc, time}; -#[derive(Debug, Deserialize, PartialEq)] +#[derive(Clone, Debug, Deserialize, PartialEq)] enum HttpVerb { #[serde(alias = "GET")] Get, @@ -107,20 +108,20 @@ impl Request { } } -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, PartialEq)] struct RequestRule { verb: HttpVerb, uri: String, } -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, PartialEq)] #[serde(rename_all = "kebab-case")] enum ConfigMode { Block, Allow, } -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] struct Config { default: ConfigMode, rules: Vec, @@ -137,12 +138,42 @@ impl Config { fn new_from_file(path: &str) -> Result> { Ok(serde_yaml::from_slice(&std::fs::read(path)?)?) } + + fn is_valid(&self, req: &Request) -> bool { + let mut matched = false; + for rule in self.rules.iter() { + if req.verb == rule.verb && req.uri == rule.uri { + if self.default == ConfigMode::Allow { + return false; + } else { + matched = true; + break; + } + } + } + (self.default == ConfigMode::Allow && !matched) + || (self.default == ConfigMode::Block && matched) + } } -async fn x_to_y_nofilter( +async fn writer_job(mut rx: mpsc::Receiver>, mut sink: tokio::net::tcp::OwnedWriteHalf) { + while let Some(blob) = rx.recv().await { + match sink.write(&blob).await { + Ok(n) => { + debug!("wrote {} bytes to ingress", n); + } + Err(err) => { + error!("while writing to ingress, error: {}", err); + return; + } + } + } +} + +async fn filter_responses( prefix: String, mut x: tokio::net::tcp::OwnedReadHalf, - mut y: tokio::net::tcp::OwnedWriteHalf, + ingress_writer: mpsc::Sender>, ) { let mut buf = [0; 1024]; loop { @@ -171,24 +202,19 @@ async fn x_to_y_nofilter( } debug!("{}: raw: {}", prefix, raw); - match y.write(&buf[..n]).await { - Ok(n) => { - debug!("{}: wrote {} bytes", prefix, n); - } - Err(err) => { - error!("{}: error on write: {}", prefix, err); - return; - } - } + ingress_writer.send(buf[..n].to_vec()).await.unwrap(); } } -async fn x_to_y_filter( +async fn filter_requests( + config: Arc, prefix: String, mut x: tokio::net::tcp::OwnedReadHalf, mut y: tokio::net::tcp::OwnedWriteHalf, + ingress_writer: mpsc::Sender>, ) { let mut buf = [0; 1024]; + let forbidden_response = "HTTP/1.1 403 Forbidden\r\n\r\n".as_bytes(); loop { let n = x.read(&mut buf).await.unwrap(); if n == 0 { @@ -204,6 +230,14 @@ async fn x_to_y_filter( } }; debug!("parsed request: {:?}", req); + if !config.is_valid(&req) { + warn!("Request does not satisfy specification. Rejecting."); + ingress_writer + .send(forbidden_response.to_vec()) + .await + .unwrap(); + return; + } match y.write(&buf[..n]).await { Ok(n) => { debug!("{}: wrote {} bytes", prefix, n); @@ -216,7 +250,7 @@ async fn x_to_y_filter( } } -async fn main_per(ingress: TcpStream, egress: TcpStream) { +async fn main_per(config: Arc, ingress: TcpStream, egress: TcpStream) { let ingress_peer_addr = ingress.peer_addr().unwrap(); let egress_peer_addr = egress.peer_addr().unwrap(); debug!( @@ -225,15 +259,19 @@ async fn main_per(ingress: TcpStream, egress: TcpStream) { ); let (ingress_read, ingress_write) = ingress.into_split(); let (egress_read, egress_write) = egress.into_split(); - let in_to_e = tokio::spawn(x_to_y_filter( + let (tx, rx) = mpsc::channel(100); + let ingress_writer_task = tokio::spawn(writer_job(rx, ingress_write)); + let in_to_e = tokio::spawn(filter_requests( + config, format!("{} to {}", ingress_peer_addr, egress_peer_addr), ingress_read, egress_write, + tx.clone(), )); - let e_to_in = tokio::spawn(x_to_y_nofilter( + let e_to_in = tokio::spawn(filter_responses( format!("{} to {}", egress_peer_addr, ingress_peer_addr), egress_read, - ingress_write, + tx, )); if let Err(err) = in_to_e.await { error!("{:?}", err); @@ -241,6 +279,9 @@ async fn main_per(ingress: TcpStream, egress: TcpStream) { if let Err(err) = e_to_in.await { error!("{:?}", err); } + if let Err(err) = ingress_writer_task.await { + error!("{:?}", err) + } debug!("done"); } @@ -263,10 +304,10 @@ fn main() -> Result<(), Box> { .version(crate_version!()) .get_matches(); - let config = match matches.value_of("config") { + let config = Arc::new(match matches.value_of("config") { Some(path) => Config::new_from_file(path)?, None => Config::new(), - }; + }); let targetaddr = String::from(matches.value_of("TARGET").unwrap()); @@ -312,7 +353,7 @@ fn main() -> Result<(), Box> { } }; - tokio::spawn(main_per(ingress, egress)); + tokio::spawn(main_per(config.clone(), ingress, egress)); } });