Skip to content

Commit

Permalink
refactor; add simple block/allow rules
Browse files Browse the repository at this point in the history
  • Loading branch information
slivingston committed Oct 23, 2024
1 parent cbbae6c commit 8fdfe70
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 25 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ serde_json = "1.0"
serde_yaml = "0.8"
tempfile = "3.1"

tokio = { version = "1.40", features = ["net", "rt", "signal"] }
tokio = { version = "1.40", features = ["net", "rt", "signal", "sync"] }

[target.'cfg(target_os="linux")'.dependencies]
v4l = { version = "0.14", features = ["v4l2"] }
Expand Down
89 changes: 65 additions & 24 deletions src/bin/rrhttp.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (C) 2024 rerobots, Inc.

use std::fmt::Write;
use std::sync::Arc;

#[macro_use]
extern crate clap;
Expand All @@ -15,9 +16,9 @@ use serde::Deserialize;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::runtime::Builder;
use tokio::{signal, time};
use tokio::{signal, sync::mpsc, time};

#[derive(Debug, Deserialize, PartialEq)]
#[derive(Clone, Debug, Deserialize, PartialEq)]
enum HttpVerb {
#[serde(alias = "GET")]
Get,
Expand Down Expand Up @@ -107,20 +108,20 @@ impl Request {
}
}

#[derive(Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize, PartialEq)]
struct RequestRule {
verb: HttpVerb,
uri: String,
}

#[derive(Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize, PartialEq)]
#[serde(rename_all = "kebab-case")]
enum ConfigMode {
Block,
Allow,
}

#[derive(Debug, Deserialize)]
#[derive(Clone, Debug, Deserialize)]
struct Config {
default: ConfigMode,
rules: Vec<RequestRule>,
Expand All @@ -137,12 +138,42 @@ impl Config {
fn new_from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
Ok(serde_yaml::from_slice(&std::fs::read(path)?)?)
}

fn is_valid(&self, req: &Request) -> bool {
let mut matched = false;
for rule in self.rules.iter() {
if req.verb == rule.verb && req.uri == rule.uri {
if self.default == ConfigMode::Allow {
return false;
} else {
matched = true;
break;
}
}
}
(self.default == ConfigMode::Allow && !matched)
|| (self.default == ConfigMode::Block && matched)
}
}

async fn x_to_y_nofilter(
async fn writer_job(mut rx: mpsc::Receiver<Vec<u8>>, mut sink: tokio::net::tcp::OwnedWriteHalf) {
while let Some(blob) = rx.recv().await {
match sink.write(&blob).await {
Ok(n) => {
debug!("wrote {} bytes to ingress", n);
}
Err(err) => {
error!("while writing to ingress, error: {}", err);
return;
}
}
}
}

async fn filter_responses(
prefix: String,
mut x: tokio::net::tcp::OwnedReadHalf,
mut y: tokio::net::tcp::OwnedWriteHalf,
ingress_writer: mpsc::Sender<Vec<u8>>,
) {
let mut buf = [0; 1024];
loop {
Expand Down Expand Up @@ -171,24 +202,19 @@ async fn x_to_y_nofilter(
}
debug!("{}: raw: {}", prefix, raw);

match y.write(&buf[..n]).await {
Ok(n) => {
debug!("{}: wrote {} bytes", prefix, n);
}
Err(err) => {
error!("{}: error on write: {}", prefix, err);
return;
}
}
ingress_writer.send(buf[..n].to_vec()).await.unwrap();
}
}

async fn x_to_y_filter(
async fn filter_requests(
config: Arc<Config>,
prefix: String,
mut x: tokio::net::tcp::OwnedReadHalf,
mut y: tokio::net::tcp::OwnedWriteHalf,
ingress_writer: mpsc::Sender<Vec<u8>>,
) {
let mut buf = [0; 1024];
let forbidden_response = "HTTP/1.1 403 Forbidden\r\n\r\n".as_bytes();
loop {
let n = x.read(&mut buf).await.unwrap();
if n == 0 {
Expand All @@ -204,6 +230,14 @@ async fn x_to_y_filter(
}
};
debug!("parsed request: {:?}", req);
if !config.is_valid(&req) {
warn!("Request does not satisfy specification. Rejecting.");
ingress_writer
.send(forbidden_response.to_vec())
.await
.unwrap();
return;
}
match y.write(&buf[..n]).await {
Ok(n) => {
debug!("{}: wrote {} bytes", prefix, n);
Expand All @@ -216,7 +250,7 @@ async fn x_to_y_filter(
}
}

async fn main_per(ingress: TcpStream, egress: TcpStream) {
async fn main_per(config: Arc<Config>, ingress: TcpStream, egress: TcpStream) {
let ingress_peer_addr = ingress.peer_addr().unwrap();
let egress_peer_addr = egress.peer_addr().unwrap();
debug!(
Expand All @@ -225,22 +259,29 @@ async fn main_per(ingress: TcpStream, egress: TcpStream) {
);
let (ingress_read, ingress_write) = ingress.into_split();
let (egress_read, egress_write) = egress.into_split();
let in_to_e = tokio::spawn(x_to_y_filter(
let (tx, rx) = mpsc::channel(100);
let ingress_writer_task = tokio::spawn(writer_job(rx, ingress_write));
let in_to_e = tokio::spawn(filter_requests(
config,
format!("{} to {}", ingress_peer_addr, egress_peer_addr),
ingress_read,
egress_write,
tx.clone(),
));
let e_to_in = tokio::spawn(x_to_y_nofilter(
let e_to_in = tokio::spawn(filter_responses(
format!("{} to {}", egress_peer_addr, ingress_peer_addr),
egress_read,
ingress_write,
tx,
));
if let Err(err) = in_to_e.await {
error!("{:?}", err);
}
if let Err(err) = e_to_in.await {
error!("{:?}", err);
}
if let Err(err) = ingress_writer_task.await {
error!("{:?}", err)
}
debug!("done");
}

Expand All @@ -263,10 +304,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.version(crate_version!())
.get_matches();

let config = match matches.value_of("config") {
let config = Arc::new(match matches.value_of("config") {
Some(path) => Config::new_from_file(path)?,
None => Config::new(),
};
});

let targetaddr = String::from(matches.value_of("TARGET").unwrap());

Expand Down Expand Up @@ -312,7 +353,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}
};

tokio::spawn(main_per(ingress, egress));
tokio::spawn(main_per(config.clone(), ingress, egress));
}
});

Expand Down

0 comments on commit 8fdfe70

Please sign in to comment.