diff --git a/src/ml_selector.rs b/src/ml_selector.rs index a8d4501..171c5fa 100644 --- a/src/ml_selector.rs +++ b/src/ml_selector.rs @@ -9,6 +9,7 @@ use serde::{Deserialize, Serialize}; use std::io::BufWriter; use std::io::Write; use std::net::TcpStream; +use std::os::unix::net::UnixStream; use std::sync::Mutex; use std::time::Instant; use std::{ @@ -18,14 +19,14 @@ use std::{ }; lazy_static! { - static ref GLOBAL_TCP_STREAM: Mutex> = Mutex::new(None); + static ref GLOBAL_UNIX_STREAM: Mutex> = Mutex::new(None); } const END_MARKER: &[u8] = b""; -fn set_global_connection(address: &str) { - let stream = TcpStream::connect(address).expect("Failed to connect to the server"); - let mut global_stream = GLOBAL_TCP_STREAM.lock().unwrap(); +fn set_global_connection(socket_path: &str) { + let stream = UnixStream::connect(socket_path).expect("Failed to connect to the server"); + let mut global_stream = GLOBAL_UNIX_STREAM.lock().unwrap(); *global_stream = Some(stream); } @@ -73,7 +74,8 @@ struct Tuple(i64, i64); pub type MLDecomposer = HeuristicEliminationDecomposer; fn main() -> io::Result<()> { - set_global_connection("127.0.0.1:5001"); + let socket_path = "/tmp/server_socket"; + set_global_connection(socket_path); let file = File::create("output.csv")?; let buf_writer = BufWriter::new(file); @@ -184,7 +186,7 @@ fn main() -> io::Result<()> { Ok(()) } -fn read_until_marker(mut stream: &mut TcpStream) -> Vec { +fn read_until_marker(mut stream: &mut UnixStream) -> Vec { let mut reader = BufReader::new(&mut stream); let mut buffer = Vec::new(); let mut chunk = [0; 4096]; @@ -217,7 +219,7 @@ fn read_until_marker(mut stream: &mut TcpStream) -> Vec { } fn ml_values(graph: &HashMapGraph, cache: &mut [i64]) -> io::Result<()> { - let mut global_stream = GLOBAL_TCP_STREAM.lock().unwrap(); + let mut global_stream = GLOBAL_UNIX_STREAM.lock().unwrap(); if let Some(ref mut stream) = *global_stream { let mut serialized_graph = graph.serialize(); serialized_graph.extend_from_slice(END_MARKER);