Skip to content

Commit

Permalink
Switch to unix sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
maxivhuber committed Jan 20, 2025
1 parent a86a923 commit 2ebe1db
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/ml_selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use serde::{Deserialize, Serialize};
use std::io::BufWriter;
use std::io::Write;
use std::net::TcpStream;

Check warning on line 11 in src/ml_selector.rs

View workflow job for this annotation

GitHub Actions / Check

unused import: `std::net::TcpStream`

Check warning on line 11 in src/ml_selector.rs

View workflow job for this annotation

GitHub Actions / Test Suite

unused import: `std::net::TcpStream`
use std::os::unix::net::UnixStream;
use std::sync::Mutex;
use std::time::Instant;
use std::{
Expand All @@ -18,14 +19,14 @@ use std::{
};

lazy_static! {
static ref GLOBAL_TCP_STREAM: Mutex<Option<TcpStream>> = Mutex::new(None);
static ref GLOBAL_UNIX_STREAM: Mutex<Option<UnixStream>> = Mutex::new(None);
}

const END_MARKER: &[u8] = b"<END>";

fn set_global_connection(address: &str) {
let stream = TcpStream::connect(address).expect("Failed to connect to the server");
let mut global_stream = GLOBAL_TCP_STREAM.lock().unwrap();
fn set_global_connection(socket_path: &str) {
let stream = UnixStream::connect(socket_path).expect("Failed to connect to the server");
let mut global_stream = GLOBAL_UNIX_STREAM.lock().unwrap();
*global_stream = Some(stream);
}

Expand Down Expand Up @@ -73,7 +74,8 @@ struct Tuple(i64, i64);
pub type MLDecomposer = HeuristicEliminationDecomposer<MLSelector>;

fn main() -> io::Result<()> {
set_global_connection("127.0.0.1:5001");
let socket_path = "/tmp/server_socket";
set_global_connection(socket_path);

let file = File::create("output.csv")?;
let buf_writer = BufWriter::new(file);
Expand Down Expand Up @@ -184,7 +186,7 @@ fn main() -> io::Result<()> {
Ok(())
}

fn read_until_marker(mut stream: &mut TcpStream) -> Vec<u8> {
fn read_until_marker(mut stream: &mut UnixStream) -> Vec<u8> {
let mut reader = BufReader::new(&mut stream);
let mut buffer = Vec::new();
let mut chunk = [0; 4096];
Expand Down Expand Up @@ -217,7 +219,7 @@ fn read_until_marker(mut stream: &mut TcpStream) -> Vec<u8> {
}

fn ml_values(graph: &HashMapGraph, cache: &mut [i64]) -> io::Result<()> {
let mut global_stream = GLOBAL_TCP_STREAM.lock().unwrap();
let mut global_stream = GLOBAL_UNIX_STREAM.lock().unwrap();
if let Some(ref mut stream) = *global_stream {
let mut serialized_graph = graph.serialize();
serialized_graph.extend_from_slice(END_MARKER);
Expand Down

0 comments on commit 2ebe1db

Please sign in to comment.