diff --git a/Cargo.lock b/Cargo.lock index 47bf4d4a4..bf55e740b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -831,7 +831,6 @@ dependencies = [ "shadowsocks", "smoltcp", "socket2 0.5.5", - "state", "tempfile", "thiserror", "tokio", @@ -1572,19 +1571,6 @@ dependencies = [ "byteorder", ] -[[package]] -name = "generator" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cc16584ff22b460a382b7feec54b23d2908d858152e5739a120b949293bd74e" -dependencies = [ - "cc", - "libc", - "log", - "rustversion", - "windows 0.48.0", -] - [[package]] name = "generic-array" version = "0.14.7" @@ -2300,21 +2286,6 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" -[[package]] -name = "loom" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff50ecb28bb86013e935fb6683ab1f6d3a20016f123c76fd4c27470076ac30f5" -dependencies = [ - "cfg-if", - "generator", - "scoped-tls", - "serde", - "serde_json", - "tracing", - "tracing-subscriber", -] - [[package]] name = "lru-cache" version = "0.1.2" @@ -3340,12 +3311,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "scoped-tls" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" - [[package]] name = "scopeguard" version = "1.2.0" @@ -3673,15 +3638,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" -[[package]] -name = "state" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b8c4a4445d81357df8b1a650d0d0d6fbbbfe99d064aa5e02f3e4022061476d8" -dependencies = [ - "loom", -] - [[package]] name = "strsim" version = "0.8.0" diff --git a/clash_lib/Cargo.toml b/clash_lib/Cargo.toml index 88b8d0ed8..40c5776d3 100644 --- a/clash_lib/Cargo.toml +++ b/clash_lib/Cargo.toml @@ -23,7 +23,6 @@ ipnet = "2.9" url = "2.5" regex = "1" byteorder = "1.5" -state = "0.6" lru_time_cache = "0.11" hyper = { version = "0.14", features = ["http1","http2","client", "server", "tcp"] } http = { version = "0.2.11" } diff --git a/clash_lib/src/app/api/handlers/config.rs b/clash_lib/src/app/api/handlers/config.rs index 82169ff4a..f4f33d5e8 100644 --- a/clash_lib/src/app/api/handlers/config.rs +++ b/clash_lib/src/app/api/handlers/config.rs @@ -1,6 +1,12 @@ -use std::sync::Arc; +use std::{path::PathBuf, sync::Arc}; + +use axum::{ + extract::{Query, State}, + response::IntoResponse, + routing::get, + Json, Router, +}; -use axum::{extract::State, response::IntoResponse, routing::get, Json, Router}; use http::StatusCode; use serde::{Deserialize, Serialize}; use tokio::sync::Mutex; @@ -52,7 +58,7 @@ async fn get_configs(State(state): State) -> impl IntoResponse { let ports = inbound_manager.get_ports(); - axum::response::Json(ConfigRequest { + axum::response::Json(PatchConfigRequest { port: ports.port, socks_port: ports.socks_port, redir_port: ports.redir_port, @@ -73,16 +79,71 @@ async fn get_configs(State(state): State) -> impl IntoResponse { }) } -async fn update_configs() -> impl IntoResponse { - ( - StatusCode::NOT_IMPLEMENTED, - axum::response::Json("don't do this please"), - ) +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +struct UpdateConfigRequest { + path: Option, + payload: Option, +} + +#[derive(Serialize, Deserialize)] +struct UploadConfigQuery { + force: Option, +} + +async fn update_configs( + _q: Query, + State(state): State, + Json(req): Json, +) -> impl IntoResponse { + let g = state.global_state.lock().await; + match (req.path, req.payload) { + (_, Some(payload)) => { + let msg = format!("config reloading from payload"); + let cfg = crate::Config::Str(payload); + match g.reload_tx.send(cfg).await { + Ok(_) => (StatusCode::ACCEPTED, msg).into_response(), + Err(_) => ( + StatusCode::INTERNAL_SERVER_ERROR, + "could not signal config reload", + ) + .into_response(), + } + } + (Some(mut path), None) => { + if !PathBuf::from(&path).is_absolute() { + path = PathBuf::from(g.cwd.clone()) + .join(path) + .to_string_lossy() + .to_string(); + } + if !PathBuf::from(&path).exists() { + return ( + StatusCode::BAD_REQUEST, + format!("config file {} not found", path), + ) + .into_response(); + } + + let msg = format!("config reloading from file {}", path); + let cfg: crate::Config = crate::Config::File(path); + match g.reload_tx.send(cfg).await { + Ok(_) => (StatusCode::ACCEPTED, msg).into_response(), + + Err(_) => ( + StatusCode::INTERNAL_SERVER_ERROR, + "could not signal config reload", + ) + .into_response(), + } + } + (None, None) => (StatusCode::BAD_REQUEST, "no path or payload provided").into_response(), + } } #[derive(Serialize, Deserialize)] #[serde(rename_all = "kebab-case")] -struct ConfigRequest { +struct PatchConfigRequest { port: Option, socks_port: Option, redir_port: Option, @@ -95,7 +156,7 @@ struct ConfigRequest { allow_lan: Option, } -impl ConfigRequest { +impl PatchConfigRequest { fn rebuild_listeners(&self) -> bool { self.port.is_some() || self.socks_port.is_some() @@ -108,7 +169,7 @@ impl ConfigRequest { async fn patch_configs( State(state): State, - Json(payload): Json, + Json(payload): Json, ) -> impl IntoResponse { if payload.allow_lan.is_some() { warn!("setting allow_lan doesn't do anything. please set bind_address to a LAN address instead."); diff --git a/clash_lib/src/common/mmdb.rs b/clash_lib/src/common/mmdb.rs index 098e62ea6..c5b38b099 100644 --- a/clash_lib/src/common/mmdb.rs +++ b/clash_lib/src/common/mmdb.rs @@ -24,7 +24,15 @@ impl MMDB { http_client: HttpClient, ) -> Result { debug!("mmdb path: {}", path.as_ref().to_string_lossy()); + let reader = Self::load_mmdb(path, download_url, &http_client).await?; + Ok(Self { reader }) + } + async fn load_mmdb>( + path: P, + download_url: Option, + http_client: &HttpClient, + ) -> Result>, Error> { let mmdb_file = path.as_ref().to_path_buf(); if !mmdb_file.exists() { @@ -43,7 +51,7 @@ impl MMDB { } match maxminddb::Reader::open_readfile(&path) { - Ok(r) => Ok(MMDB { reader: r }), + Ok(r) => Ok(r), Err(e) => match e { maxminddb::MaxMindDBError::InvalidDatabaseError(_) | maxminddb::MaxMindDBError::IoError(_) => { @@ -62,15 +70,13 @@ impl MMDB { .map_err(|x| { Error::InvalidConfig(format!("mmdb download failed: {}", x)) })?; - Ok(MMDB { - reader: maxminddb::Reader::open_readfile(&path).map_err(|x| { - Error::InvalidConfig(format!( - "cant open mmdb `{}`: {}", - path.as_ref().to_string_lossy(), - x.to_string() - )) - })?, - }) + Ok(maxminddb::Reader::open_readfile(&path).map_err(|x| { + Error::InvalidConfig(format!( + "cant open mmdb `{}`: {}", + path.as_ref().to_string_lossy(), + x.to_string() + )) + })?) } else { return Err(Error::InvalidConfig(format!( "mmdb `{}` not found and mmdb_download_url is not set", @@ -89,8 +95,8 @@ impl MMDB { } } - #[async_recursion(?Send)] - async fn download>( + #[async_recursion] + async fn download + std::marker::Send>( url: &str, path: P, http_client: &HttpClient, @@ -129,10 +135,9 @@ impl MMDB { Ok(()) } - pub fn lookup(&self, ip: IpAddr) -> anyhow::Result { + pub fn lookup(&self, ip: IpAddr) -> std::io::Result { self.reader - .lookup(ip) + .lookup::(ip) .map_err(map_io_error) - .map_err(|x| x.into()) } } diff --git a/clash_lib/src/lib.rs b/clash_lib/src/lib.rs index a6f412ca9..f5b479ccd 100644 --- a/clash_lib/src/lib.rs +++ b/clash_lib/src/lib.rs @@ -18,10 +18,9 @@ use common::mmdb; use config::def::LogLevel; use proxy::tun::get_tun_runner; -use state::InitCell; use std::io; use std::path::PathBuf; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use thiserror::Error; use tokio::sync::{broadcast, mpsc, Mutex}; use tokio::task::JoinHandle; @@ -93,15 +92,18 @@ impl Config { pub struct GlobalState { log_level: LogLevel, inbound_listener_handle: Option>>, - #[allow(dead_code)] + tunnel_listener_handle: Option>>, + api_listener_handle: Option>>, dns_listener_handle: Option>>, + reload_tx: mpsc::Sender, + cwd: String, } pub struct RuntimeController { shutdown_tx: mpsc::Sender<()>, } -static RUNTIME_CONTROLLER: InitCell> = InitCell::new(); +static RUNTIME_CONTROLLER: OnceLock> = OnceLock::new(); pub fn start(opts: Options) -> Result<(), Error> { let rt = match opts.rt.as_ref().unwrap_or(&TokioRuntime::MultiThread) { @@ -125,7 +127,7 @@ pub fn start(opts: Options) -> Result<(), Error> { } pub fn shutdown() -> bool { - match RUNTIME_CONTROLLER.get().write() { + match RUNTIME_CONTROLLER.get().unwrap().write() { Ok(rt) => rt.shutdown_tx.blocking_send(()).is_ok(), _ => false, } @@ -134,12 +136,11 @@ pub fn shutdown() -> bool { async fn start_async(opts: Options) -> Result<(), Error> { let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1); - RUNTIME_CONTROLLER.set(std::sync::RwLock::new(RuntimeController { shutdown_tx })); + let _ = RUNTIME_CONTROLLER.set(std::sync::RwLock::new(RuntimeController { shutdown_tx })); let config: InternalConfig = opts.config.try_parse()?; let cwd = opts.cwd.unwrap_or_else(|| ".".to_string()); - let cwd = std::path::Path::new(&cwd); let (log_tx, _) = broadcast::channel(100); @@ -148,7 +149,7 @@ async fn start_async(opts: Options) -> Result<(), Error> { let _g = app::logging::setup_logging( config.general.log_level, log_collector, - cwd.to_str().unwrap(), + &cwd, opts.log_file, ) .map_err(|x| Error::InvalidConfig(format!("failed to setup logging: {}", x.to_string())))?; @@ -168,6 +169,7 @@ async fn start_async(opts: Options) -> Result<(), Error> { let client = new_http_client(system_resolver).map_err(|x| Error::DNSError(x.to_string()))?; debug!("initializing mmdb"); + let cwd = PathBuf::from(cwd); let mmdb = Arc::new( mmdb::MMDB::new( cwd.join(&config.general.mmdb), @@ -248,27 +250,35 @@ async fn start_async(opts: Options) -> Result<(), Error> { let inbound_listener_handle = tokio::spawn(inbound_runner); let tun_runner = get_tun_runner(config.tun, dispatcher.clone(), dns_resolver.clone())?; - if let Some(tun_runner) = tun_runner { - runners.push(tun_runner); - } + let tun_runner_handle = if let Some(tun_runner) = tun_runner { + Some(tokio::spawn(tun_runner)) + } else { + None + }; debug!("initializing dns listener"); let dns_listener_handle = dns::get_dns_listener(config.dns, dns_resolver.clone()) .await .map(|l| tokio::spawn(l)); + let (reload_tx, mut reload_rx) = mpsc::channel(1); + let global_state = Arc::new(Mutex::new(GlobalState { log_level: config.general.log_level, inbound_listener_handle: Some(inbound_listener_handle), + tunnel_listener_handle: tun_runner_handle, dns_listener_handle, + reload_tx, + api_listener_handle: None, + cwd: cwd.to_string_lossy().to_string(), })); let api_runner = app::api::get_api_runner( config.general.controller, log_tx, - inbound_manager, + inbound_manager.clone(), dispatcher, - global_state, + global_state.clone(), dns_resolver, outbound_manager, statistics_manager, @@ -277,7 +287,8 @@ async fn start_async(opts: Options) -> Result<(), Error> { cwd.to_string_lossy().to_string(), ); if let Some(r) = api_runner { - runners.push(r); + let api_listener_handle = tokio::spawn(r); + global_state.lock().await.api_listener_handle = Some(api_listener_handle); } runners.push(Box::pin(async move { @@ -295,6 +306,134 @@ async fn start_async(opts: Options) -> Result<(), Error> { Ok(()) })); + tasks.push(Box::pin(async move { + while let Some(config) = reload_rx.recv().await { + info!("reloading config"); + let config = match config.try_parse() { + Ok(c) => c, + Err(e) => { + error!("failed to reload config: {}", e); + continue; + } + }; + + debug!("reloading dns resolver"); + let system_resolver = + Arc::new(SystemResolver::new().map_err(|x| Error::DNSError(x.to_string()))?); + let client = + new_http_client(system_resolver).map_err(|x| Error::DNSError(x.to_string()))?; + + debug!("reloading mmdb"); + let mmdb = Arc::new( + mmdb::MMDB::new( + cwd.join(&config.general.mmdb), + config.general.mmdb_download_url, + client, + ) + .await?, + ); + + debug!("reloading cache store"); + let cache_store = profile::ThreadSafeCacheFile::new( + cwd.join("cache.db").as_path().to_str().unwrap(), + config.profile.store_selected, + ); + + let dns_resolver = + dns::Resolver::new(&config.dns, cache_store.clone(), mmdb.clone()).await; + + debug!("reloading outbound manager"); + let outbound_manager = Arc::new( + OutboundManager::new( + config + .proxies + .into_values() + .filter_map(|x| match x { + OutboundProxy::ProxyServer(s) => Some(s), + _ => None, + }) + .collect(), + config + .proxy_groups + .into_values() + .filter_map(|x| match x { + OutboundProxy::ProxyGroup(g) => Some(g), + _ => None, + }) + .collect(), + config.proxy_providers, + config.proxy_names, + dns_resolver.clone(), + cache_store.clone(), + cwd.to_string_lossy().to_string(), + ) + .await?, + ); + + debug!("reloading router"); + let router = Arc::new( + Router::new( + config.rules, + config.rule_providers, + dns_resolver.clone(), + mmdb, + cwd.to_string_lossy().to_string(), + ) + .await, + ); + + let statistics_manager = StatisticsManager::new(); + + let dispatcher = Arc::new(Dispatcher::new( + outbound_manager.clone(), + router.clone(), + dns_resolver.clone(), + config.general.mode, + statistics_manager.clone(), + )); + + let authenticator = Arc::new(auth::PlainAuthenticator::new(config.users)); + + debug!("reloading inbound manager"); + let inbound_manager = Arc::new(Mutex::new(InboundManager::new( + config.general.inbound, + dispatcher.clone(), + authenticator, + )?)); + + let mut g = global_state.lock().await; + if let Some(h) = g.inbound_listener_handle.take() { + h.abort(); + } + if let Some(h) = g.tunnel_listener_handle.take() { + h.abort(); + } + if let Some(h) = g.dns_listener_handle.take() { + h.abort(); + } + + let inbound_runner = inbound_manager.lock().await.get_runner()?; + let inbound_listener_handle = tokio::spawn(inbound_runner); + + let tun_runner = get_tun_runner(config.tun, dispatcher.clone(), dns_resolver.clone())?; + let tun_runner_handle = if let Some(tun_runner) = tun_runner { + Some(tokio::spawn(tun_runner)) + } else { + None + }; + + debug!("initializing dns listener"); + let dns_listener_handle = dns::get_dns_listener(config.dns, dns_resolver.clone()) + .await + .map(|l| tokio::spawn(l)); + + g.inbound_listener_handle = Some(inbound_listener_handle); + g.tunnel_listener_handle = tun_runner_handle; + g.dns_listener_handle = dns_listener_handle; + } + Ok(()) + })); + futures::future::select_all(tasks).await.0.map_err(|x| { error!("runtime error: {}, shutting down", x); x @@ -302,6 +441,7 @@ async fn start_async(opts: Options) -> Result<(), Error> { } #[cfg(test)] +#[allow(non_snake_case)] #[ctor::ctor] fn setup_tests() { println!("setup tests");