diff --git a/.gitignore b/.gitignore index ea8c4bf..96ef6c0 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +Cargo.lock diff --git a/.travis.yml b/.travis.yml index 589f77f..2320e28 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,14 +4,19 @@ dist: xenial matrix: include: - # Stable channel. + # Build on linux - os: linux - rust: stable - env: TARGET=x86_64-unknown-linux-gnu + rust: nightly + env: + - TARGET=x86_64-unknown-linux-gnu + - DEPLOY=true + # Build on osx - os: osx - rust: stable - env: TARGET=x86_64-apple-darwin + rust: nightly + env: + - TARGET=x86_64-apple-darwin + - DEPLOY=true # Code formatting check - os: linux @@ -21,8 +26,9 @@ matrix: - cargo install --debug --force rustfmt-nightly script: cargo fmt -- --check + # Test benchmark - os: linux - rust: stable + rust: nightly install: cargo build script: bash ci/benchmark.bash @@ -75,7 +81,7 @@ deploy: # deploy only if we push a tag tags: true # deploy only on stable channel that has TARGET env variable sets - condition: $TRAVIS_RUST_VERSION = stable && $TARGET != "" + condition: $DEPLOY = true && $TARGET != "" notifications: email: diff --git a/Cargo.toml b/Cargo.toml index 018bc5a..b8f5de3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,19 +19,18 @@ edition = "2018" name = "ag" path = "src/main.rs" +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + [dependencies] -percent-encoding = "2.1" -term_size = "0.3" -ansi_term = "0.12" bytes = "0.5" -failure = "0.1" - +isahc = { version = "0.9", features = ["cookies"] } futures = "0.3" -futures-util = "0.3" -actix-rt = "1.0" -awc = { version = "1.0", features = ["openssl"] } - -[dependencies.clap] -version = "2.33" -default-features = false -features = [ "suggestions", "color" ] +async-std = { version = "1", features = ["unstable"] } +thiserror = "1.0" +term_size = "0.3" +ansi_term = "0.12" +openssl = "0.10" +clap = "2" +percent-encoding = "2" +m3u8-rs = "1" +url = "2" diff --git a/src/app.rs b/src/app.rs deleted file mode 100644 index 62ae172..0000000 --- a/src/app.rs +++ /dev/null @@ -1,206 +0,0 @@ -use std::{env, path::Path}; - -#[cfg(windows)] -use ansi_term::enable_ansi_support; - -use clap::ArgMatches; - -use percent_encoding::percent_decode; - -use awc::http::Uri; - -use crate::{ - clap_app::build_app, - common::AGET_EXT, - error::{ArgError, Result}, - util::{escape_nonascii, LiteralSize}, -}; - -#[derive(Debug, Clone)] -pub struct Config { - pub(crate) uri: String, - pub(crate) method: String, - pub(crate) headers: Vec, - pub(crate) data: Option, - pub(crate) path: String, - pub(crate) concurrency: u64, - pub(crate) chunk_length: u64, - pub(crate) timeout: u64, - pub(crate) max_retries: u32, - pub(crate) retry_wait: u64, - pub(crate) debug: bool, - pub(crate) quiet: bool, -} - -impl Config { - pub fn new( - uri: String, - method: String, - headers: Vec, - data: Option, - path: String, - concurrency: u64, - timeout: u64, - chunk_length: u64, - max_retries: u32, - retry_wait: u64, - debug: bool, - quiet: bool, - ) -> Config { - Config { - uri, - method, - headers, - data, - path, - concurrency, - timeout, - chunk_length, - max_retries, - retry_wait, - debug, - quiet, - } - } -} - -pub struct App { - pub matches: ArgMatches<'static>, -} - -impl App { - pub fn new() -> App { - #[cfg(windows)] - let _ = enable_ansi_support(); - - App { - matches: Self::matches(), - } - } - - fn matches() -> ArgMatches<'static> { - let args = env::args(); - let matches = build_app().get_matches_from(args); - matches - } - - pub fn config(&self) -> Result { - // uri - let uri = self.matches.value_of("URL").map(escape_nonascii).unwrap(); - - // path - let path = if let Some(path) = self.matches.value_of("out") { - path.to_string() - } else { - let uri = &uri.parse::()?; - let path = Path::new(uri.path()); - if let Some(file_name) = path.file_name() { - percent_decode(file_name.to_str().unwrap().as_bytes()) - .decode_utf8() - .unwrap() - .to_string() - } else { - return Err(ArgError::NoFilename); - } - }; - - // check status of task - let path_ = Path::new(&path); - let mut file_name = path_.file_name().unwrap().to_os_string(); - file_name.push(AGET_EXT); - let mut aget_path = path_.to_path_buf(); - aget_path.set_file_name(file_name); - if path_.is_dir() { - return Err(ArgError::PathIsDirectory); - } - if path_.exists() && !aget_path.as_path().exists() { - return Err(ArgError::FileExists); - } - - let path = path.to_string(); - - // data - let data = if let Some(data) = self.matches.value_of("data") { - Some(data.to_string()) - } else { - None - }; - - // method - let method = if let Some(method) = self.matches.value_of("method") { - method.to_string() - } else { - if data.is_some() { - "POST".to_owned() - } else { - "GET".to_owned() - } - }; - - // headers - let headers = if let Some(headers) = self.matches.values_of("header") { - headers.map(String::from).collect::>() - } else { - Vec::new() - }; - - // concurrency - let concurrency = if let Some(concurrency) = self.matches.value_of("concurrency") { - concurrency.parse::()? - } else { - 10 - }; - - // chunk length - let chunk_length = if let Some(chunk_length) = self.matches.value_of("chunk-length") { - chunk_length.literal_size()? - } else { - 1024 * 500 // 500k - }; - - // timeout - let timeout = if let Some(timeout) = self.matches.value_of("timeout") { - timeout.parse::()? - } else { - 10 - }; - - // maximum retries - let max_retries = if let Some(max_retries) = self.matches.value_of("max_retries") { - max_retries.parse::()? - } else { - 5 - }; - - let retry_wait = if let Some(retry_wait) = self.matches.value_of("retry_wait") { - retry_wait.parse::()? - } else { - 5 - }; - - let debug = self.matches.is_present("debug"); - - let quiet = self.matches.is_present("quiet"); - - Ok(Config::new( - uri, - method, - headers, - data, - path, - concurrency, - timeout, - chunk_length, - max_retries, - retry_wait, - debug, - quiet, - )) - } -} - -fn test_escape_nonascii() { - let s = ":ss/s 来;】/ 【【 ? 是的 & 水电费=45 进来看"; - println!("{}", s); - println!("{}", escape_nonascii(s)); -} diff --git a/src/app/core/http.rs b/src/app/core/http.rs new file mode 100644 index 0000000..1cf83fb --- /dev/null +++ b/src/app/core/http.rs @@ -0,0 +1,397 @@ +use std::{path::PathBuf, sync::Arc}; + +use async_std::{io::ReadExt, process::exit, task as std_task}; + +use futures::{ + channel::mpsc::{channel, Sender}, + SinkExt, +}; + +use crate::{ + app::{ + receive::http_receiver::HttpReceiver, + stats::range_stats::{RangeStats, RANGESTATS_FILE_SUFFIX}, + }, + common::{ + buf::SIZE, + bytes::bytes_type::{Buf, Bytes, BytesMut}, + errors::{Error, Result}, + file::File, + net::{ + net::{build_http_client, content_length, redirect, request}, + net_type::{ContentLengthValue, HttpClient, Method, Uri}, + }, + range::{split_pair, RangePair, SharedRangList}, + }, + features::{args::Args, running::Runnable, stack::StackLike}, +}; + +/// Http task handler +pub struct HttpHandler { + output: PathBuf, + method: Method, + uri: Uri, + headers: Vec<(String, String)>, + data: Option, + timeout: u64, + concurrency: u64, + chunk_size: u64, + retries: u64, + retry_wait: u64, + proxy: Option, + client: Arc, +} + +impl HttpHandler { + pub fn new(args: &impl Args) -> Result { + let headers = args.headers(); + let timeout = args.timeout(); + let proxy = args.proxy(); + + let hds: Vec<(&str, &str)> = headers + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect(); + let client = build_http_client(hds.as_ref(), timeout, proxy.as_deref())?; + + debug!("HttpHandler::new"); + + Ok(HttpHandler { + output: args.output(), + method: args.method(), + uri: args.uri(), + headers, + data: args.data().map(|ref mut d| d.to_bytes()), + timeout, + concurrency: args.concurrency(), + chunk_size: args.chunk_length(), + retries: args.retries(), + retry_wait: args.retry_wait(), + proxy, + client: Arc::new(client), + }) + } + + async fn start(&mut self) -> Result<()> { + debug!("HttpHandler::start"); + + // 0. Check whether task is completed + debug!("HttpHandler: check whether task is completed"); + let mut rangestats = + RangeStats::new(&*(self.output.to_string_lossy() + RANGESTATS_FILE_SUFFIX))?; + if self.output.exists() && !rangestats.exists() { + return Ok(()); + } + + // 1. Redirect + debug!("HttpHandler: redirect start"); + let uri = redirect( + &self.client, + self.method.clone(), + self.uri.clone(), + self.data.clone(), + ) + .await?; + debug!("HttpHandler: redirect to", uri); + self.uri = uri; + + // 2. get content_length + debug!("HttpHandler: content_length start"); + let cl = content_length( + &self.client, + self.method.clone(), + self.uri.clone(), + self.data.clone(), + ) + .await?; + debug!("HttpHandler: content_length", cl); + + // 3. Compare recorded content length with the above one + debug!("HttpHandler: compare recorded content length"); + let mut direct = true; + if let ContentLengthValue::RangeLength(cl) = cl { + if self.output.exists() { + if rangestats.exists() { + rangestats.open()?; + } else { + // Task is completed + return Ok(()); + } + } else { + // Init rangestats + rangestats.remove().unwrap_or(()); // Missing error + rangestats.open()?; + } + + let pre_cl = rangestats.total()?; + + // Inital rangestats + if pre_cl == 0 && pre_cl != cl { + rangestats.write_total(cl)?; + direct = false; + } + // Content is empty + else if pre_cl == 0 && pre_cl == cl { + File::new(&self.output, true)?.open()?; + rangestats.remove()?; + return Ok(()); + } + // Content length is not consistent + else if pre_cl != 0 && pre_cl != cl { + return Err(Error::ContentLengthIsNotConsistent); + } + // Rewrite statistic status + else if pre_cl != 0 && pre_cl == cl { + rangestats.rewrite()?; + direct = false; + } + } + + // 4. Create channel + let (sender, receiver) = channel::<(RangePair, Bytes)>(self.concurrency as usize + 10); + + // 5. Dispatch Task + debug!("HttpHandler: dispatch task: direct", direct); + if direct { + let mut task = DirectRequestTask::new( + self.client.clone(), + self.method.clone(), + self.uri.clone(), + self.data.clone(), + sender.clone(), + ); + std_task::spawn(async move { + task.start().await; + }); + } else { + // Make range pairs stack + let mut stack = vec![]; + let gaps = rangestats.gaps()?; + for gap in gaps.iter() { + let mut list = split_pair(gap, self.chunk_size); + stack.append(&mut list); + } + stack.reverse(); + let stack = SharedRangList::new(stack); + // let stack = SharedRangList::new(rangestats.gaps()?); + debug!("HttpHandler: range stack length", stack.len()); + + let concurrency = std::cmp::min(stack.len() as u64, self.concurrency); + for i in 1..concurrency + 1 { + let mut task = RangeRequestTask::new( + self.client.clone(), + self.method.clone(), + self.uri.clone(), + self.data.clone(), + stack.clone(), + sender.clone(), + i, + ); + std_task::spawn(async move { + task.start().await; + }); + } + } + drop(sender); // Remove the reference and let `Task` to handle it + + // 6. Create receiver + debug!("HttpHandler: create receiver"); + let mut httpreceiver = HttpReceiver::new(&self.output, direct)?; + httpreceiver.start(receiver).await?; + + // 7. Task succeeds. Remove rangestats file + rangestats.remove().unwrap_or(()); // Missing error + Ok(()) + } +} + +impl Runnable for HttpHandler { + fn run(&mut self) -> Result<()> { + std_task::block_on(self.start()) + } +} + +/// Directly request the resource without range header +struct DirectRequestTask { + client: Arc, + method: Method, + uri: Uri, + data: Option, + sender: Sender<(RangePair, Bytes)>, +} + +impl DirectRequestTask { + fn new( + client: Arc, + method: Method, + uri: Uri, + data: Option, + sender: Sender<(RangePair, Bytes)>, + ) -> DirectRequestTask { + DirectRequestTask { + client, + method, + uri, + data, + sender, + } + } + + async fn start(&mut self) { + loop { + let resp = request( + &*self.client, + self.method.clone(), + self.uri.clone(), + self.data.clone(), + None, + ) + .await; + if let Err(err) = resp { + print_err!("DirectRequestTask request error", err); + continue; + } + let resp = resp.unwrap(); + + let mut buf = [0; SIZE]; + let mut offset = 0; + let mut reader = resp.into_body(); + loop { + match reader.read(&mut buf).await { + Ok(0) => { + return; + } + Ok(len) => { + let pair = RangePair::new(offset, offset + len as u64 - 1); // The pair is a closed interval + let mut b = BytesMut::from(&buf[..len]); + if self.sender.send((pair, b.to_bytes())).await.is_err() { + break; + } + offset += len as u64; + } + Err(err) => { + print_err!("DirectRequestTask read error", err); + break; + } + } + } + } + } +} + +/// Request the resource with a range header which is in the `SharedRangList` +struct RangeRequestTask { + client: Arc, + method: Method, + uri: Uri, + data: Option, + stack: SharedRangList, + sender: Sender<(RangePair, Bytes)>, + id: u64, +} + +impl RangeRequestTask { + fn new( + client: Arc, + method: Method, + uri: Uri, + data: Option, + stack: SharedRangList, + sender: Sender<(RangePair, Bytes)>, + id: u64, + ) -> RangeRequestTask { + RangeRequestTask { + client, + method, + uri, + data, + stack, + sender, + id, + } + } + + async fn start(&mut self) { + while let Some(pair) = self.stack.pop() { + match self.req(pair).await { + // Exit whole process when `Error::InnerError` is returned + Err(Error::InnerError(msg)) => { + print_err!(format!("RangeRequestTask {}: InnerError", self.id), msg); + exit(1); + } + Err(err) => { + print_err!(format!("RangeRequestTask {}: error", self.id), err); + } + _ => {} + } + } + } + + async fn req(&mut self, pair: RangePair) -> Result<()> { + let resp = request( + &*self.client, + self.method.clone(), + self.uri.clone(), + self.data.clone(), + Some(pair), + ) + .await; + + if let Err(err) = resp { + self.stack.push(pair); + return Err(err); + } + let resp = resp.unwrap(); + + let length = pair.length(); + let mut count = 0; + let mut buf = [0; SIZE]; + let mut offset = pair.begin; + let mut reader = resp.into_body(); + loop { + // Reads some bytes from the byte stream. + // + // Returns the number of bytes read from the start of the buffer. + // + // If the return value is `Ok(n)`, then it must be guaranteed that + // `0 <= n <= buf.len()`. A nonzero `n` value indicates that the buffer has been + // filled in with `n` bytes of data. If `n` is `0`, then it can indicate one of two + // scenarios: + // + // 1. This reader has reached its "end of file" and will likely no longer be able to + // produce bytes. Note that this does not mean that the reader will always no + // longer be able to produce bytes. + // 2. The buffer specified was 0 bytes in length. + match reader.read(&mut buf).await { + Ok(0) => { + if count != length { + let pr = RangePair::new(offset, pair.end); + self.stack.push(pr); + return Err(Error::UncompletedRead); + } else { + return Ok(()); + } + } + Ok(len) => { + let pr = RangePair::new(offset, offset + len as u64 - 1); // The pair is a closed interval + let mut b = BytesMut::from(&buf[..len]); + if let Err(err) = self.sender.send((pr, b.to_bytes())).await { + let pr = RangePair::new(offset, pair.end); + self.stack.push(pr); + return Err(Error::InnerError(format!( + "Error at `http::RangeRequestTask`: Sender error: {:?}", + err + ))); + } + offset += len as u64; + count += len as u64; + } + Err(err) => { + let pr = RangePair::new(offset, pair.end); + self.stack.push(pr); + return Err(err.into()); + } + } + } + } +} diff --git a/src/app/core/m3u8/common.rs b/src/app/core/m3u8/common.rs new file mode 100644 index 0000000..690dcad --- /dev/null +++ b/src/app/core/m3u8/common.rs @@ -0,0 +1,150 @@ +use std::collections::HashMap; + +use m3u8_rs::{parse_playlist_res, playlist::Playlist}; + +use crate::common::{ + bytes::{ + bytes::{decode_hex, u32_to_u8x4}, + bytes_type::Bytes, + }, + errors::{Error, Result}, + list::SharedVec, + net::{ + net::{redirect, request}, + net_type::{HttpClient, Method, ResponseExt, Uri, Url}, + }, +}; + +#[derive(Debug, Clone)] +pub struct M3u8Segment { + pub index: u64, + pub method: Method, + pub uri: Uri, + pub data: Option, + pub key: Option<[u8; 16]>, + pub iv: Option<[u8; 16]>, +} + +impl M3u8Segment { + pub fn new( + index: u64, + method: Method, + uri: Uri, + data: Option, + key: Option<[u8; 16]>, + iv: Option<[u8; 16]>, + ) -> M3u8Segment { + M3u8Segment { + index, + method, + uri, + data, + key, + iv, + } + } +} + +pub type M3u8SegmentList = Vec; + +pub type SharedM3u8SegmentList = SharedVec; + +pub async fn get_m3u8( + client: &HttpClient, + method: Method, + uri: Uri, + data: Option, +) -> Result { + // uri -> (key, iv) + let mut keymap: HashMap = HashMap::new(); + let mut uris = vec![uri]; + let mut list = vec![]; + + while let Some(uri) = uris.pop() { + debug!("m3u8", uri); + let u = redirect(client, method.clone(), uri.clone(), data.clone()).await?; + debug!("m3u8 redirect to", u); + + if u != uri { + uris.push(u.clone()); + continue; + } + + let base_url = Url::parse(&format!("{:?}", u))?; + + // Read m3u8 content + let mut resp = request(client, method.clone(), u.clone(), data.clone(), None).await?; + + // Adding "\n" for the case when response content has not "\n" at end. + let cn = resp.text()? + "\n"; + + // Parse m3u8 content + let parsed = parse_playlist_res(cn.as_bytes()); + match parsed { + Ok(Playlist::MasterPlaylist(mut pl)) => { + pl.variants.reverse(); + for variant in &pl.variants { + let url = base_url.join(&variant.uri)?; + let uri: Uri = url.as_str().parse()?; + uris.push(uri); + } + } + Ok(Playlist::MediaPlaylist(pl)) => { + let mut index = pl.media_sequence as u64; + for segment in &pl.segments { + let seg_url = base_url.join(&segment.uri)?; + let seg_uri: Uri = seg_url.as_str().parse()?; + + let (key, iv) = if let Some(key) = &segment.key { + let iv = if let Some(iv) = &key.iv { + let mut i = [0; 16]; + let buf = decode_hex(&iv[2..])?; + i.clone_from_slice(&buf[..]); + i + } else { + let mut iv = [0; 16]; + let index_bin = u32_to_u8x4(index as u32); + iv[12..].clone_from_slice(&index_bin); + iv + }; + if let Some(uri) = &key.uri { + let key_url = base_url.join(&uri)?; + let key_uri: Uri = key_url.as_str().parse()?; + if let Some((k, iv)) = keymap.get(&key_uri) { + (Some(*k), Some(*iv)) + } else { + let k = get_key(client, Method::GET, key_uri.clone()).await?; + keymap.insert(key_uri.clone(), (k, iv)); + debug!("Get key, iv", (k, iv)); + (Some(k), Some(iv)) + } + } else { + (None, None) + } + } else { + (None, None) + }; + + list.push(M3u8Segment::new( + index, + Method::GET, + seg_uri.clone(), + None, + key, + iv, + )); + index += 1; + } + } + Err(_) => return Err(Error::M3U8ParseFail), + } + } + Ok(list) +} + +async fn get_key(client: &HttpClient, method: Method, uri: Uri) -> Result<[u8; 16]> { + let mut resp = request(client, method.clone(), uri.clone(), None, None).await?; + let mut cn = [0; 16]; + resp.copy_to(&mut cn[..])?; + Ok(cn) +} diff --git a/src/app/core/m3u8/m3u8.rs b/src/app/core/m3u8/m3u8.rs new file mode 100644 index 0000000..6059168 --- /dev/null +++ b/src/app/core/m3u8/m3u8.rs @@ -0,0 +1,229 @@ +use std::{ + path::PathBuf, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::Duration, +}; + +use async_std::{io::ReadExt, process::exit, task as std_task}; + +use futures::{ + channel::mpsc::{channel, Sender}, + SinkExt, +}; + +use crate::{ + app::{ + core::m3u8::common::{get_m3u8, M3u8Segment, SharedM3u8SegmentList}, + receive::m3u8_receiver::M3u8Receiver, + stats::list_stats::{ListStats, LISTSTATS_FILE_SUFFIX}, + }, + common::{ + bytes::bytes_type::{Buf, Bytes}, + crypto::decrypt_aes128, + errors::{Error, Result}, + net::{ + net::{build_http_client, request}, + net_type::{HttpClient, Method, Uri}, + }, + }, + features::{args::Args, running::Runnable, stack::StackLike}, +}; + +/// M3u8 task handler +pub struct M3u8Handler { + output: PathBuf, + method: Method, + uri: Uri, + headers: Vec<(String, String)>, + data: Option, + timeout: u64, + concurrency: u64, + proxy: Option, + client: Arc, +} + +impl M3u8Handler { + pub fn new(args: &impl Args) -> Result { + let headers = args.headers(); + let timeout = args.timeout(); + let proxy = args.proxy(); + + let hds: Vec<(&str, &str)> = headers + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect(); + let client = build_http_client(hds.as_ref(), timeout, proxy.as_deref())?; + + debug!("M3u8Handler::new"); + + Ok(M3u8Handler { + output: args.output(), + method: args.method(), + uri: args.uri(), + headers, + data: args.data().map(|ref mut d| d.to_bytes()), + timeout, + concurrency: args.concurrency(), + proxy, + client: Arc::new(client), + }) + } + + async fn start(&mut self) -> Result<()> { + debug!("M3u8Handler::start"); + + // 0. Check whether task is completed + debug!("M3u8Handler: check whether task is completed"); + let mut liststats = + ListStats::new(&*(self.output.to_string_lossy() + LISTSTATS_FILE_SUFFIX))?; + if self.output.exists() && !liststats.exists() { + return Ok(()); + } + + // 1. Redirect + debug!("M3u8Handler: get m3u8"); + let mut ls = get_m3u8( + &self.client, + self.method.clone(), + self.uri.clone(), + self.data.clone(), + ) + .await?; + ls.reverse(); + + if liststats.exists() { + liststats.open()?; + let total = liststats.total()?; + if total != ls.len() as u64 { + return Err(Error::PartsAreNotConsistent); + } else { + let index = liststats.index()?; + ls.truncate((total - index) as usize); + } + } else { + liststats.open()?; + liststats.write_total(ls.len() as u64)?; + } + + let index = ls.last().unwrap().index; + let sharedindex = Arc::new(AtomicU64::new(index)); + let stack = SharedM3u8SegmentList::new(ls); + debug!("M3u8Handler: segments", stack.len()); + + // 4. Create channel + let (sender, receiver) = channel::<(u64, Bytes)>(self.concurrency as usize + 10); + + let concurrency = std::cmp::min(stack.len() as u64, self.concurrency); + for i in 1..concurrency + 1 { + let mut task = RequestTask::new( + self.client.clone(), + stack.clone(), + sender.clone(), + i, + sharedindex.clone(), + ); + std_task::spawn(async move { + task.start().await; + }); + } + drop(sender); // Remove the reference and let `Task` to handle it + + // 6. Create receiver + debug!("M3u8Handler: create receiver"); + let mut m3u8receiver = M3u8Receiver::new(&self.output)?; + m3u8receiver.start(receiver).await?; + + // 7. Task succeeds. Remove liststats file + liststats.remove().unwrap_or(()); // Missing error + Ok(()) + } +} + +impl Runnable for M3u8Handler { + fn run(&mut self) -> Result<()> { + std_task::block_on(self.start()) + } +} + +/// Request the resource with a range header which is in the `SharedRangList` +struct RequestTask { + client: Arc, + stack: SharedM3u8SegmentList, + sender: Sender<(u64, Bytes)>, + id: u64, + shared_index: Arc, +} + +impl RequestTask { + fn new( + client: Arc, + stack: SharedM3u8SegmentList, + sender: Sender<(u64, Bytes)>, + id: u64, + sharedindex: Arc, + ) -> RequestTask { + RequestTask { + client, + stack, + sender, + id, + shared_index: sharedindex, + } + } + + async fn start(&mut self) { + while let Some(segment) = self.stack.pop() { + loop { + match self.req(segment.clone()).await { + // Exit whole process when `Error::InnerError` is returned + Err(Error::InnerError(msg)) => { + print_err!(format!("RequestTask {}: InnerError", self.id), msg); + exit(1); + } + Err(err) => { + print_err!(format!("RequestTask {}: error", self.id), err); + std_task::sleep(Duration::from_secs(1)).await; + } + _ => break, + } + } + } + } + + async fn req(&mut self, segment: M3u8Segment) -> Result<()> { + let resp = request( + &*self.client, + segment.method.clone(), + segment.uri.clone(), + segment.data.clone(), + None, + ) + .await?; + + let index = segment.index; + let mut buf: Vec = vec![]; + let mut reader = resp.into_body(); + reader.read_to_end(&mut buf).await?; + if let (Some(key), Some(iv)) = (segment.key, segment.iv) { + buf = decrypt_aes128(&key[..], &iv[..], &buf[..])?; + } + + loop { + if self.shared_index.load(Ordering::SeqCst) == index { + if let Err(err) = self.sender.send((index, Bytes::from(buf))).await { + return Err(Error::InnerError(format!( + "Error at `http::RequestTask`: Sender error: {:?}", + err + ))); + } + self.shared_index.store(index + 1, Ordering::SeqCst); + return Ok(()); + } else { + std_task::sleep(Duration::from_millis(500)).await; + } + } + } +} diff --git a/src/app/core/m3u8/mod.rs b/src/app/core/m3u8/mod.rs new file mode 100644 index 0000000..10dfd4e --- /dev/null +++ b/src/app/core/m3u8/mod.rs @@ -0,0 +1,2 @@ +pub mod common; +pub mod m3u8; diff --git a/src/app/core/mod.rs b/src/app/core/mod.rs new file mode 100644 index 0000000..1330dcd --- /dev/null +++ b/src/app/core/mod.rs @@ -0,0 +1,2 @@ +pub mod http; +pub mod m3u8; diff --git a/src/app/mod.rs b/src/app/mod.rs new file mode 100644 index 0000000..561b64b --- /dev/null +++ b/src/app/mod.rs @@ -0,0 +1,5 @@ +pub mod core; +pub mod receive; +pub mod show; +pub mod stats; +pub mod status; diff --git a/src/app/receive/http_receiver.rs b/src/app/receive/http_receiver.rs new file mode 100644 index 0000000..d7ab4dd --- /dev/null +++ b/src/app/receive/http_receiver.rs @@ -0,0 +1,126 @@ +use std::{io::SeekFrom, path::Path, time::Duration}; + +use async_std::stream; +use futures::{channel::mpsc::Receiver, select, stream::StreamExt}; + +use crate::{ + app::{ + show::http_show::HttpShower, + stats::range_stats::{RangeStats, RANGESTATS_FILE_SUFFIX}, + status::rate_status::RateStatus, + }, + common::{bytes::bytes_type::Bytes, errors::Result, file::File, range::RangePair}, +}; + +pub struct HttpReceiver { + output: File, + rangestats: Option, + ratestatus: RateStatus, + shower: HttpShower, + // Total content length of the uri + total: u64, +} + +impl HttpReceiver { + pub fn new>(output: P, direct: bool) -> Result { + let mut outputfile = File::new(&output, true)?; + outputfile.open()?; + + let (rangestats, total, completed) = if direct { + (None, 0, 0) + } else { + let mut rangestats = + RangeStats::new(&*(output.as_ref().to_string_lossy() + RANGESTATS_FILE_SUFFIX))?; + rangestats.open()?; + let total = rangestats.total()?; + let completed = rangestats.count()?; + (Some(rangestats), total, completed) + }; + + let mut ratestatus = RateStatus::new(); + ratestatus.set_total(completed); + + Ok(HttpReceiver { + output: outputfile, + rangestats, + ratestatus, + shower: HttpShower::new(), + // receiver, + total, + }) + } + + fn show_infos(&mut self) -> Result<()> { + if self.rangestats.is_none() { + self.shower + .print_msg("Server doesn't support range request.")?; + } + + let file_name = &self.output.file_name().unwrap_or("[No Name]"); + let total = self.total; + self.shower.print_file(file_name)?; + self.shower.print_total(total)?; + // self.shower.print_concurrency(concurrency)?; + self.show_status()?; + Ok(()) + } + + fn show_status(&mut self) -> Result<()> { + let total = self.total; + let completed = self.ratestatus.total(); + let rate = self.ratestatus.rate(); + + let eta = if self.rangestats.is_some() { + let remains = total - completed; + // rate > 1.0 for overflow + if remains > 0 && rate > 1.0 { + let eta = (remains as f64 / rate) as u64; + // eta is large than 99 days, return 0 + if eta > 99 * 24 * 60 * 60 { + 0 + } else { + eta + } + } else { + 0 + } + } else { + 0 + }; + + self.shower.print_status(completed, total, rate, eta)?; + self.ratestatus.clean(); + Ok(()) + } + + fn record_pair(&mut self, pair: RangePair) -> Result<()> { + if let Some(ref mut rangestats) = self.rangestats { + rangestats.write_pair(pair)?; + } + Ok(()) + } + pub async fn start(&mut self, receiver: Receiver<(RangePair, Bytes)>) -> Result<()> { + self.show_infos()?; + + let mut tick = stream::interval(Duration::from_secs(2)).fuse(); + let mut receiver = receiver.fuse(); + loop { + select! { + item = receiver.next() => { + if let Some((pair, chunk)) = item { + self.output.write(&chunk[..], Some(SeekFrom::Start(pair.begin)))?; + self.record_pair(pair)?; + self.ratestatus.add(pair.length()); + } else { + break; + } + }, + _ = tick.next() => { + self.show_status()?; + }, + } + } + self.show_status()?; + Ok(()) + } +} diff --git a/src/app/receive/m3u8_receiver.rs b/src/app/receive/m3u8_receiver.rs new file mode 100644 index 0000000..e84a003 --- /dev/null +++ b/src/app/receive/m3u8_receiver.rs @@ -0,0 +1,90 @@ +use std::{io::SeekFrom, path::Path, time::Duration}; + +use async_std::stream; +use futures::{channel::mpsc::Receiver, select, stream::StreamExt}; + +use crate::{ + app::{ + show::m3u8_show::M3u8Shower, + stats::list_stats::{ListStats, LISTSTATS_FILE_SUFFIX}, + status::rate_status::RateStatus, + }, + common::{bytes::bytes_type::Bytes, errors::Result, file::File}, +}; + +pub struct M3u8Receiver { + output: File, + liststats: ListStats, + ratestatus: RateStatus, + shower: M3u8Shower, + // Total number of the `SharedM3u8SegmentList` + total: u64, + completed: u64, +} + +impl M3u8Receiver { + pub fn new>(output: P) -> Result { + let mut outputfile = File::new(&output, true)?; + outputfile.open()?; + + let mut liststats = + ListStats::new(&*(output.as_ref().to_string_lossy() + LISTSTATS_FILE_SUFFIX))?; + liststats.open()?; + let total = liststats.total()?; + let completed = liststats.index()?; + + Ok(M3u8Receiver { + output: outputfile, + liststats, + ratestatus: RateStatus::new(), + shower: M3u8Shower::new(), + total, + completed, + }) + } + + fn show_infos(&mut self) -> Result<()> { + let file_name = &self.output.file_name().unwrap_or("[No Name]"); + let total = self.total; + self.shower.print_file(file_name)?; + self.shower.print_total(total)?; + self.show_status()?; + Ok(()) + } + + fn show_status(&mut self) -> Result<()> { + let total = self.total; + let completed = self.completed; + let rate = self.ratestatus.rate(); + + self.shower.print_status(completed, total, rate)?; + self.ratestatus.clean(); + Ok(()) + } + + pub async fn start(&mut self, receiver: Receiver<(u64, Bytes)>) -> Result<()> { + self.show_infos()?; + + let mut tick = stream::interval(Duration::from_secs(2)).fuse(); + let mut receiver = receiver.fuse(); + loop { + select! { + item = receiver.next() => { + if let Some((index, chunk)) = item { + self.output.write(&chunk[..], Some(SeekFrom::End(0)))?; + self.liststats.write_index(index + 1)?; + self.ratestatus.add(chunk.len() as u64); + self.completed = index + 1; + } else { + break; + } + }, + _ = tick.next() => { + self.show_status()?; + }, + } + } + self.show_status()?; + Ok(()) + } +} diff --git a/src/app/receive/mod.rs b/src/app/receive/mod.rs new file mode 100644 index 0000000..19dc232 --- /dev/null +++ b/src/app/receive/mod.rs @@ -0,0 +1,2 @@ +pub mod http_receiver; +pub mod m3u8_receiver; diff --git a/src/app/show/common.rs b/src/app/show/common.rs new file mode 100644 index 0000000..2eb5369 --- /dev/null +++ b/src/app/show/common.rs @@ -0,0 +1,11 @@ +#[cfg(target_os = "windows")] +pub fn bars() -> (&'static str, &'static str, &'static str) { + // bar, bar_right, bar_left + ("▓", "░", " ") +} + +#[cfg(not(target_os = "windows"))] +pub fn bars() -> (&'static str, &'static str, &'static str) { + // bar, bar_right, bar_left + ("━", "╸", "╺") +} diff --git a/src/app/show/http_show.rs b/src/app/show/http_show.rs new file mode 100644 index 0000000..01e56e2 --- /dev/null +++ b/src/app/show/http_show.rs @@ -0,0 +1,123 @@ +use std::io::{stdout, Stdout, Write}; + +use crate::{ + app::show::common::bars, + common::{ + colors::{Black, Blue, Cyan, Green, Red, Yellow}, + errors::Result, + liberal::ToDate, + size::HumanReadable, + terminal::terminal_width, + }, +}; + +pub struct HttpShower { + stdout: Stdout, +} + +impl HttpShower { + pub fn new() -> HttpShower { + HttpShower { stdout: stdout() } + } + + pub fn print_msg(&mut self, msg: &str) -> Result<()> { + writeln!(&mut self.stdout, "\n {}", Yellow.italic().paint(msg))?; + Ok(()) + } + + pub fn print_file(&mut self, path: &str) -> Result<()> { + writeln!( + &mut self.stdout, + // "\n {}: {}", + "\n{}: {}", + Green.bold().paint("File"), + path, + )?; + Ok(()) + } + + pub fn print_total(&mut self, total: u64) -> Result<()> { + writeln!( + &mut self.stdout, + "{}: {} ({})", + Blue.bold().paint("Length"), + total.human_readable(), + total, + )?; + Ok(()) + } + + pub fn print_concurrency(&mut self, concurrency: u64) -> Result<()> { + writeln!( + &mut self.stdout, + "{}: {}\n", + Yellow.bold().paint("concurrency"), + concurrency, + )?; + Ok(()) + } + + pub fn print_status(&mut self, completed: u64, total: u64, rate: f64, eta: u64) -> Result<()> { + let percent = completed as f64 / total as f64; + + let completed_str = completed.human_readable(); + let total_str = total.human_readable(); + let percent_str = format!("{:.2}", percent * 100.0); + let rate_str = rate.human_readable(); + let eta_str = eta.date(); + + // maximum info length is 41 e.g. + // 1001.3k/1021.9m 97.98% 1003.1B/s eta: 12s + let info = format!( + "{completed}/{total} {percent}% {rate}/s eta: {eta}", + completed = completed_str, + total = total_str, + percent = percent_str, + rate = rate_str, + eta = eta_str, + ); + + // set default info length + let info_length = 41; + let miss = info_length - info.len(); + + let terminal_width = terminal_width(); + let bar_length = terminal_width - info_length as u64 - 3; + let process_bar_length = (bar_length as f64 * percent) as u64; + let blank_length = bar_length - process_bar_length; + + let (bar, bar_right, bar_left) = bars(); + + let bar_done_str = if process_bar_length > 0 { + format!( + "{}{}", + bar.repeat((process_bar_length - 1) as usize), + bar_right + ) + } else { + "".to_owned() + }; + let bar_undone_str = if blank_length > 0 { + format!("{}{}", bar_left, bar.repeat(blank_length as usize - 1)) + } else { + "".to_owned() + }; + + write!( + &mut self.stdout, + "\r{completed}/{total} {percent}% {rate}/s eta: {eta}{miss} {process_bar}{blank} ", + completed = Red.bold().paint(completed_str), + total = Green.bold().paint(total_str), + percent = Yellow.bold().paint(percent_str), + rate = Blue.bold().paint(rate_str), + eta = Cyan.bold().paint(eta_str), + miss = " ".repeat(miss), + process_bar = Red.bold().paint(bar_done_str), + blank = Black.bold().paint(bar_undone_str), + )?; + + self.stdout.flush()?; + + Ok(()) + } +} diff --git a/src/app/show/m3u8_show.rs b/src/app/show/m3u8_show.rs new file mode 100644 index 0000000..719b687 --- /dev/null +++ b/src/app/show/m3u8_show.rs @@ -0,0 +1,119 @@ +use std::io::{stdout, Stdout, Write}; + +use crate::{ + app::show::common::bars, + common::{ + colors::{Black, Blue, Green, Red, Yellow}, + errors::Result, + size::HumanReadable, + terminal::terminal_width, + }, +}; + +pub struct M3u8Shower { + stdout: Stdout, +} + +impl M3u8Shower { + pub fn new() -> M3u8Shower { + M3u8Shower { stdout: stdout() } + } + + pub fn print_msg(&mut self, msg: &str) -> Result<()> { + writeln!(&mut self.stdout, "\n {}", Yellow.italic().paint(msg))?; + Ok(()) + } + + pub fn print_file(&mut self, path: &str) -> Result<()> { + writeln!( + &mut self.stdout, + // "\n {}: {}", + "\n{}: {}", + Green.bold().paint("File"), + path, + )?; + Ok(()) + } + + pub fn print_total(&mut self, total: u64) -> Result<()> { + writeln!( + &mut self.stdout, + "{}: {}", + Blue.bold().paint("Segments"), + total, + )?; + Ok(()) + } + + pub fn print_concurrency(&mut self, concurrency: u64) -> Result<()> { + writeln!( + &mut self.stdout, + "{}: {}\n", + Yellow.bold().paint("concurrency"), + concurrency, + )?; + Ok(()) + } + + pub fn print_status(&mut self, completed: u64, total: u64, rate: f64) -> Result<()> { + let percent = completed as f64 / total as f64; + + let completed_str = completed.to_string(); + let total_str = total.to_string(); + let percent_str = format!("{:.2}", percent * 100.0); + let rate_str = rate.human_readable(); + + // maximum info length is `completed_str.len()` + `total_str.len()` + 19 + // e.g. + // 100/1021 97.98% 1003.1B/s eta: 12s + let info = format!( + "{completed}/{total} {percent}% {rate}/s", + completed = completed_str, + total = total_str, + percent = percent_str, + rate = rate_str, + ); + + // set default info length + let info_length = completed_str.len() + total_str.len() + 19; + let miss = info_length - info.len(); + + let terminal_width = terminal_width(); + let bar_length = terminal_width - info_length as u64 - 3; + let process_bar_length = (bar_length as f64 * percent) as u64; + let blank_length = bar_length - process_bar_length; + + let (bar, bar_right, bar_left) = bars(); + + let bar_done_str = if process_bar_length > 0 { + format!( + "{}{}", + bar.repeat((process_bar_length - 1) as usize), + bar_right + ) + } else { + "".to_owned() + }; + let bar_undone_str = if blank_length > 0 { + format!("{}{}", bar_left, bar.repeat(blank_length as usize - 1)) + } else { + "".to_owned() + }; + + write!( + &mut self.stdout, + "\r{completed}/{total} {percent}% {rate}/s{miss} {process_bar}{blank} ", + completed = Red.bold().paint(completed_str), + total = Green.bold().paint(total_str), + percent = Yellow.bold().paint(percent_str), + rate = Blue.bold().paint(rate_str), + miss = " ".repeat(miss), + process_bar = Red.bold().paint(bar_done_str), + blank = Black.bold().paint(bar_undone_str), + )?; + + self.stdout.flush()?; + + Ok(()) + } +} diff --git a/src/app/show/mod.rs b/src/app/show/mod.rs new file mode 100644 index 0000000..87fee76 --- /dev/null +++ b/src/app/show/mod.rs @@ -0,0 +1,3 @@ +pub mod common; +pub mod http_show; +pub mod m3u8_show; diff --git a/src/app/stats/list_stats.rs b/src/app/stats/list_stats.rs new file mode 100644 index 0000000..fc11ff1 --- /dev/null +++ b/src/app/stats/list_stats.rs @@ -0,0 +1,72 @@ +use std::{io::SeekFrom, path::Path}; + +use crate::common::{ + bytes::bytes::{u64_to_u8x8, u8x8_to_u64}, + errors::Result, + file::File, +}; + +pub const LISTSTATS_FILE_SUFFIX: &'static str = ".ls.aget"; + +/// List statistic +/// +/// `ListStats` struct records total and index two number. +/// All information is stored at a local file. +/// +/// [total 8bit][index 8bit] +/// `total` is given by user, presenting as the real total number of items of a list. +pub struct ListStats { + inner: File, +} + +impl ListStats { + pub fn new>(path: P) -> Result { + let inner = File::new(path, true)?; + Ok(ListStats { inner }) + } + + pub fn open(&mut self) -> Result<&mut Self> { + self.inner.open()?; + Ok(self) + } + + pub fn file_name(&self) -> Option<&str> { + self.inner.file_name() + } + + pub fn exists(&self) -> bool { + self.inner.exists() + } + + /// Delete the inner file + pub fn remove(&self) -> Result<()> { + self.inner.remove() + } + + /// Get downloading file's content length stored in the aget file + pub fn total(&mut self) -> Result { + let mut buf: [u8; 8] = [0; 8]; + self.inner.read(&mut buf, Some(SeekFrom::Start(0)))?; + let cl = u8x8_to_u64(&buf); + Ok(cl) + } + + pub fn index(&mut self) -> Result { + let mut buf: [u8; 8] = [0; 8]; + self.inner.read(&mut buf, Some(SeekFrom::Start(8)))?; + let cl = u8x8_to_u64(&buf); + Ok(cl) + } + + pub fn write_total(&mut self, total: u64) -> Result<()> { + let buf = u64_to_u8x8(total); + self.inner.write(&buf, Some(SeekFrom::Start(0)))?; + Ok(()) + } + + pub fn write_index(&mut self, index: u64) -> Result<()> { + let buf = u64_to_u8x8(index); + self.inner.write(&buf, Some(SeekFrom::Start(8)))?; + Ok(()) + } +} diff --git a/src/app/stats/mod.rs b/src/app/stats/mod.rs new file mode 100644 index 0000000..57fa9e9 --- /dev/null +++ b/src/app/stats/mod.rs @@ -0,0 +1,2 @@ +pub mod list_stats; +pub mod range_stats; diff --git a/src/app/stats/range_stats.rs b/src/app/stats/range_stats.rs new file mode 100644 index 0000000..a12ccba --- /dev/null +++ b/src/app/stats/range_stats.rs @@ -0,0 +1,194 @@ +use std::{cmp::max, io::SeekFrom, path::Path}; + +use crate::common::{ + bytes::bytes::{u64_to_u8x8, u8x8_to_u64}, + errors::Result, + file::File, + range::{RangeList, RangePair}, +}; + +pub const RANGESTATS_FILE_SUFFIX: &'static str = ".rg.aget"; + +/// Range statistic +/// +/// This struct records pairs which are `common::range::RangePair`. +/// All information is stored at a local file. +/// +/// [total 8bit][ [begin1 8bit,end1 8bit] [begin2 8bit,end2 8bit] ... ] +/// `total` position is not sum_i{end_i - begin_i + 1}. It is given by +/// user, presenting as the real total number. +pub struct RangeStats { + inner: File, +} + +impl RangeStats { + pub fn new>(path: P) -> Result { + let inner = File::new(path, true)?; + Ok(RangeStats { inner }) + } + + pub fn open(&mut self) -> Result<&mut Self> { + self.inner.open()?; + Ok(self) + } + + pub fn file_name(&self) -> Option<&str> { + self.inner.file_name() + } + + pub fn exists(&self) -> bool { + self.inner.exists() + } + + /// Delete the inner file + pub fn remove(&self) -> Result<()> { + self.inner.remove() + } + + /// Get downloading file's content length stored in the aget file + pub fn total(&mut self) -> Result { + let mut buf: [u8; 8] = [0; 8]; + self.inner.read(&mut buf, Some(SeekFrom::Start(0)))?; + let cl = u8x8_to_u64(&buf); + Ok(cl) + } + + /// Count the length of total pairs + pub fn count(&mut self) -> Result { + let pairs = self.pairs()?; + if pairs.is_empty() { + Ok(0) + } else { + Ok(pairs.iter().map(RangePair::length).sum()) + } + } + + /// Recorded pairs + pub fn pairs(&mut self) -> Result { + let mut pairs: Vec<(u64, u64)> = Vec::new(); + + let mut buf: [u8; 16] = [0; 16]; + self.inner.seek(SeekFrom::Start(8))?; + loop { + let s = self.inner.read(&mut buf, None)?; + if s != 16 { + break; + } + + let mut raw = [0; 8]; + raw.clone_from_slice(&buf[..8]); + let begin = u8x8_to_u64(&raw); + raw.clone_from_slice(&buf[8..]); + let end = u8x8_to_u64(&raw); + + assert!( + begin <= end, + format!( + "Bug: `begin > end` in an pair of {}. : {} > {}", + self.file_name().unwrap_or(""), + begin, + end + ) + ); + + pairs.push((begin, end)); + } + + pairs.sort(); + + // merge pairs + let mut merged_pairs: Vec<(u64, u64)> = Vec::new(); + if !pairs.is_empty() { + merged_pairs.push(pairs[0]); + } + for (begin, end) in pairs.iter() { + let (pre_start, pre_end) = merged_pairs.last().unwrap().clone(); + + // case 1 + // ---------- + // ----------- + if pre_end + 1 < *begin { + merged_pairs.push((*begin, *end)); + // case 2 + // ----------------- + // ---------- + // -------- + // ------ + } else { + let n_start = pre_start; + let n_end = max(pre_end, *end); + merged_pairs.pop(); + merged_pairs.push((n_start, n_end)); + } + } + + Ok(merged_pairs + .iter() + .map(|(begin, end)| RangePair::new(*begin, *end)) + .collect::()) + } + + /// Get gaps between all pairs + /// Each of gap is a closed interval + pub fn gaps(&mut self) -> Result { + let mut pairs = self.pairs()?; + let total = self.total()?; + pairs.push(RangePair::new(total, total)); + + // find gaps + let mut gaps: RangeList = Vec::new(); + // find first chunk + let RangePair { begin, .. } = pairs[0]; + if begin > 0 { + gaps.push(RangePair::new(0, begin - 1)); + } + + for (index, RangePair { end, .. }) in pairs.iter().enumerate() { + if let Some(RangePair { + begin: next_start, .. + }) = pairs.get(index + 1) + { + if end + 1 < *next_start { + gaps.push(RangePair::new(end + 1, next_start - 1)); + } + } + } + + Ok(gaps) + } + + pub fn write_total(&mut self, total: u64) -> Result<()> { + let buf = u64_to_u8x8(total); + self.inner.write(&buf, Some(SeekFrom::Start(0)))?; + Ok(()) + } + + pub fn write_pair(&mut self, pair: RangePair) -> Result<()> { + let begin = u64_to_u8x8(pair.begin); + let end = u64_to_u8x8(pair.end); + let buf = [begin, end].concat(); + self.inner.write(&buf, Some(SeekFrom::End(0)))?; + Ok(()) + } + + // Merge completed pairs and rewrite the aget file + pub fn rewrite(&mut self) -> Result<()> { + let total = self.total()?; + let pairs = self.pairs()?; + + let mut buf: Vec = Vec::new(); + buf.extend(&u64_to_u8x8(total)); + for pair in pairs.iter() { + buf.extend(&u64_to_u8x8(pair.begin)); + buf.extend(&u64_to_u8x8(pair.end)); + } + + // Clean all content of the file and set its length to zero + self.inner.set_len(0)?; + + // Write new data + self.inner.write(buf.as_slice(), Some(SeekFrom::Start(0)))?; + + Ok(()) + } +} diff --git a/src/app/status/mod.rs b/src/app/status/mod.rs new file mode 100644 index 0000000..2d89a27 --- /dev/null +++ b/src/app/status/mod.rs @@ -0,0 +1 @@ +pub mod rate_status; diff --git a/src/app/status/rate_status.rs b/src/app/status/rate_status.rs new file mode 100644 index 0000000..8771ddb --- /dev/null +++ b/src/app/status/rate_status.rs @@ -0,0 +1,63 @@ +use std::time::Instant; + +/// `RateStatus` records the rate of adding number +pub struct RateStatus { + /// Total number + total: u64, + + /// The number at an one tick interval + count: u64, + + /// The interval of one tick + tick: Instant, +} + +impl RateStatus { + pub fn new() -> RateStatus { + RateStatus::default() + } + + pub fn total(&self) -> u64 { + self.total + } + + pub fn set_total(&mut self, total: u64) { + self.total = total; + } + + pub fn count(&self) -> u64 { + self.count + } + + pub fn rate(&self) -> f64 { + let interval = self.tick.elapsed().as_secs_f64(); + let rate = self.count as f64 / interval; + rate + } + + pub fn add(&mut self, incr: u64) { + self.total += incr; + self.count += incr; + } + + pub fn reset(&mut self) { + self.total = 0; + self.count = 0; + self.tick = Instant::now(); + } + + pub fn clean(&mut self) { + self.count = 0; + self.tick = Instant::now(); + } +} + +impl Default for RateStatus { + fn default() -> RateStatus { + RateStatus { + total: 0, + count: 0, + tick: Instant::now(), + } + } +} diff --git a/src/clap_app.rs b/src/arguments/clap_app.rs similarity index 87% rename from src/clap_app.rs rename to src/arguments/clap_app.rs index 1c337ab..b86c434 100644 --- a/src/clap_app.rs +++ b/src/arguments/clap_app.rs @@ -1,6 +1,6 @@ use clap::{crate_name, crate_version, App as ClapApp, AppSettings, Arg}; -pub fn build_app() -> ClapApp<'static, 'static> { +pub fn build_app<'a>() -> ClapApp<'a, 'a> { ClapApp::new(crate_name!()) .version(crate_version!()) .global_setting(AppSettings::ColoredHelp) @@ -66,12 +66,18 @@ pub fn build_app() -> ClapApp<'static, 'static> { .multiple(false) .takes_value(true), ) + .arg( + Arg::with_name("proxy") + .long("proxy") + .help("proxy (http/https/socks4/socks5) e.g. -p http://localhost:1024") + .multiple(false) + .takes_value(true), + ) .arg( Arg::with_name("timeout") .short("t") .long("timeout") .help("Timeout(seconds) of request") - .default_value("10") .multiple(false) .takes_value(true), ) @@ -91,6 +97,13 @@ pub fn build_app() -> ClapApp<'static, 'static> { .multiple(false) .takes_value(true), ) + .arg( + Arg::with_name("type") + .long("type") + .default_value("http") + .multiple(false) + .help("Task type, http/m3u8"), + ) .arg( Arg::with_name("debug") .long("debug") diff --git a/src/arguments/cmd_args.rs b/src/arguments/cmd_args.rs new file mode 100644 index 0000000..6ccb0a7 --- /dev/null +++ b/src/arguments/cmd_args.rs @@ -0,0 +1,239 @@ +use std::{ + env, fmt, + path::{Path, PathBuf}, +}; + +#[cfg(windows)] +use ansi_term::enable_ansi_support; + +use clap::ArgMatches; + +use percent_encoding::percent_decode; + +use crate::{ + arguments::clap_app::build_app, + common::{ + bytes::bytes_type::BytesMut, + character::escape_nonascii, + errors::Error, + liberal::ParseLiteralNumber, + net::{ + net::parse_headers, + net_type::{Method, Uri}, + }, + tasks::TaskType, + }, + features::args::Args, +}; + +pub struct CmdArgs { + matches: ArgMatches<'static>, +} + +impl CmdArgs { + pub fn new() -> CmdArgs { + #[cfg(windows)] + let _ = enable_ansi_support(); + + let args = env::args(); + let inner = build_app(); + let matches = inner.get_matches_from(args); + CmdArgs { matches } + } +} + +impl Args for CmdArgs { + /// Path of output + fn output(&self) -> PathBuf { + if let Some(path) = self.matches.value_of("out") { + PathBuf::from(path) + } else { + let uri = self.uri(); + let path = Path::new(uri.path()); + if let Some(file_name) = path.file_name() { + PathBuf::from( + percent_decode(file_name.to_str().unwrap().as_bytes()) + .decode_utf8() + .unwrap() + .to_string(), + ) + } else { + panic!("{:?}", Error::NoFilename); + } + } + } + + /// Request method for http + fn method(&self) -> Method { + if let Some(method) = self.matches.value_of("method") { + match method.to_uppercase().as_str() { + "GET" => Method::GET, + "POST" => Method::POST, + _ => panic!(format!( + "{:?}", + Error::UnsupportedMethod(method.to_string()) + )), + } + } else { + if self.data().is_some() { + Method::POST + } else { + Method::GET + } + } + } + + /// The uri of a task + fn uri(&self) -> Uri { + self.matches + .value_of("URL") + .map(escape_nonascii) + .unwrap() + .parse() + .unwrap() + } + + /// The data for http post request + fn data(&self) -> Option { + self.matches.value_of("data").map(|d| BytesMut::from(d)) + } + + /// Request headers + fn headers(&self) -> Vec<(String, String)> { + if let Some(headers) = self.matches.values_of("header") { + parse_headers(headers) + .unwrap() + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect::>() + } else { + vec![] + } + } + + /// Set proxy througth arg or environment variable + /// + /// The environment variables can be: + /// http_proxy [protocol://][:port] + /// Sets the proxy server to use for HTTP. + /// + /// HTTPS_PROXY [protocol://][:port] + /// Sets the proxy server to use for HTTPS. + /// + /// ALL_PROXY [protocol://][:port] + /// Sets the proxy server to use if no protocol-specific proxy is set. + /// + /// Protocols: + /// http:// + /// an HTTP proxy + /// https:// + /// as HTTPS proxy + /// socks4:// + /// socks4a:// + /// socks5:// + /// socks5h:// + /// as SOCKS proxy + fn proxy(&self) -> Option { + let p = self.matches.value_of("proxy").map(|i| i.to_string()); + if p.is_some() { + return p; + } + + if let Ok(p) = env::var("http_proxy") { + return Some(p); + } + + if let Ok(p) = env::var("HTTPS_PROXY") { + return Some(p); + } + + if let Ok(p) = env::var("ALL_PROXY") { + return Some(p); + } + + None + } + + /// The maximum time the request is allowed to take. + fn timeout(&self) -> u64 { + self.matches + .value_of("timeout") + .map(|i| i.parse::().unwrap()) + .unwrap_or(0) + } + + /// The number of concurrency + fn concurrency(&self) -> u64 { + self.matches + .value_of("concurrency") + .map(|i| i.parse::().unwrap()) + .unwrap_or(10) + } + + /// The chunk length of each concurrency for http task + fn chunk_length(&self) -> u64 { + self.matches + .value_of("chunk-length") + .map(|i| i.literal_number().unwrap()) + .unwrap_or(1024 * 500) // 500k + } + + /// The number of retry of a task, default is 5 + fn retries(&self) -> u64 { + self.matches + .value_of("retries") + .map(|i| i.parse::().unwrap()) + .unwrap_or(5) + } + + /// The internal of each retry, default is zero + fn retry_wait(&self) -> u64 { + self.matches + .value_of("retry_wait") + .map(|i| i.parse::().unwrap()) + .unwrap_or(0) + } + + /// Task type + fn task_type(&self) -> TaskType { + self.matches + .value_of("type") + .map(|i| match i.to_lowercase().as_str() { + "http" => TaskType::HTTP, + "m3u8" => TaskType::M3U8, + _ => panic!(format!("{:?}", Error::UnsupportedTask(i.to_string()))), + }) + .unwrap_or(TaskType::HTTP) + } + + /// To debug mode, if it returns true + fn debug(&self) -> bool { + self.matches.is_present("debug") + } + + /// To quiet mode, if it return true + fn quiet(&self) -> bool { + self.matches.is_present("quiet") + } +} + +impl fmt::Debug for CmdArgs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CmdArgs") + .field("output", &self.output()) + .field("method", &self.method()) + .field("uri", &self.uri()) + .field("data", &self.data()) + .field("headers", &self.headers()) + .field("proxy", &self.proxy()) + .field("timeout", &self.timeout()) + .field("concurrency", &self.concurrency()) + .field("chunk_length", &self.chunk_length()) + .field("retries", &self.retries()) + .field("retry_wait", &self.retry_wait()) + .field("task_type", &self.task_type()) + .field("debug", &self.debug()) + .field("quiet", &self.quiet()) + .finish() + } +} diff --git a/src/arguments/mod.rs b/src/arguments/mod.rs new file mode 100644 index 0000000..351f335 --- /dev/null +++ b/src/arguments/mod.rs @@ -0,0 +1,2 @@ +pub mod clap_app; +pub mod cmd_args; diff --git a/src/chunk.rs b/src/chunk.rs deleted file mode 100644 index c0ea516..0000000 --- a/src/chunk.rs +++ /dev/null @@ -1,40 +0,0 @@ -use std::{cell::RefCell, rc::Rc}; - -#[derive(Debug, Clone)] -pub struct RangePart { - pub start: u64, - pub end: u64, -} - -impl RangePart { - pub fn new(start: u64, end: u64) -> RangePart { - RangePart { start, end } - } - - pub fn length(&self) -> u64 { - self.end - self.start + 1 - } -} - -pub type RangeStack = Rc>>; - -/// Split a close `interval` to many piece chunk that its size is equal to `chunk_length`, -/// but the last piece size can be less then `chunk_length`. -pub fn make_range_chunks(interval: &RangePart, chunk_length: u64) -> Vec { - let mut stack = Vec::new(); - - let mut start = interval.start; - let interval_end = interval.end; - - while start + chunk_length - 1 <= interval_end { - let end = start + chunk_length - 1; - stack.push(RangePart::new(start, end)); - start += chunk_length; - } - - if start <= interval_end { - stack.push(RangePart::new(start, interval_end)); - } - - stack -} diff --git a/src/common.rs b/src/common.rs deleted file mode 100644 index e16d9ec..0000000 --- a/src/common.rs +++ /dev/null @@ -1 +0,0 @@ -pub const AGET_EXT: &'static str = ".aget"; diff --git a/src/common/buf.rs b/src/common/buf.rs new file mode 100644 index 0000000..a979a83 --- /dev/null +++ b/src/common/buf.rs @@ -0,0 +1,2 @@ +/// Default buffer size +pub const SIZE: usize = 2048; diff --git a/src/common/bytes/bytes.rs b/src/common/bytes/bytes.rs new file mode 100644 index 0000000..bc70d79 --- /dev/null +++ b/src/common/bytes/bytes.rs @@ -0,0 +1,25 @@ +use std::num::ParseIntError; + +use crate::common::errors::Result; + +/// Create an integer value from its representation as a byte array in big endian. +pub fn u8x8_to_u64(u8x8: &[u8; 8]) -> u64 { + u64::from_be_bytes(*u8x8) +} + +/// Return the memory representation of this integer as a byte array in big-endian (network) byte +/// order. +pub fn u64_to_u8x8(u: u64) -> [u8; 8] { + u.to_be_bytes() +} + +pub fn u32_to_u8x4(u: u32) -> [u8; 4] { + u.to_be_bytes() +} + +pub fn decode_hex(s: &str) -> Result, ParseIntError> { + (0..s.len()) + .step_by(2) + .map(|i| u8::from_str_radix(&s[i..i + 2], 16)) + .collect() +} diff --git a/src/common/bytes/bytes_type.rs b/src/common/bytes/bytes_type.rs new file mode 100644 index 0000000..1bda7da --- /dev/null +++ b/src/common/bytes/bytes_type.rs @@ -0,0 +1 @@ +pub use bytes::{Buf, BufMut, Bytes, BytesMut}; diff --git a/src/common/bytes/mod.rs b/src/common/bytes/mod.rs new file mode 100644 index 0000000..2dcdc9b --- /dev/null +++ b/src/common/bytes/mod.rs @@ -0,0 +1,2 @@ +pub mod bytes; +pub mod bytes_type; diff --git a/src/common/character.rs b/src/common/character.rs new file mode 100644 index 0000000..80a8702 --- /dev/null +++ b/src/common/character.rs @@ -0,0 +1,24 @@ +use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS}; + +const FRAGMENT: &AsciiSet = &CONTROLS + .add(b' ') + .remove(b'?') + .remove(b'/') + .remove(b':') + .remove(b'='); + +pub fn escape_nonascii(target: &str) -> String { + utf8_percent_encode(target, FRAGMENT).to_string() +} + +#[cfg(test)] +mod tests { + use super::escape_nonascii; + + #[test] + fn test_escape_nonascii() { + let s = ":ss/s 来;】/ 【【 ? 是的 & 水电费=45 进来看"; + println!("{}", s); + println!("{}", escape_nonascii(s)); + } +} diff --git a/src/common/colors.rs b/src/common/colors.rs new file mode 100644 index 0000000..db103b7 --- /dev/null +++ b/src/common/colors.rs @@ -0,0 +1 @@ +pub use ansi_term::Colour::*; diff --git a/src/common/crypto.rs b/src/common/crypto.rs new file mode 100644 index 0000000..08e774a --- /dev/null +++ b/src/common/crypto.rs @@ -0,0 +1,8 @@ +use openssl::symm::{decrypt, Cipher}; + +use crate::common::errors::Result; + +pub fn decrypt_aes128(key: &[u8], iv: &[u8], enc: &[u8]) -> Result> { + let cipher = Cipher::aes_128_cbc(); + Ok(decrypt(cipher, key, Some(iv), enc)?) +} diff --git a/src/common/debug.rs b/src/common/debug.rs new file mode 100644 index 0000000..57d7ce9 --- /dev/null +++ b/src/common/debug.rs @@ -0,0 +1,27 @@ +pub static mut DEBUG: bool = false; +pub static mut QUIET: bool = false; + +#[macro_export] +macro_rules! print_err { + ( $ctx:expr, $err:expr ) => { + eprintln!("[{}:{}] {}: {}", file!(), line!(), $ctx, $err); + }; +} + +#[macro_export] +macro_rules! debug { + ( $title:expr, $msg:expr ) => { + unsafe { + if crate::common::debug::DEBUG { + eprintln!("[{}:{}] {}: {:#?}", file!(), line!(), $title, $msg); + } + } + }; + ( $title:expr ) => { + unsafe { + if crate::common::debug::DEBUG { + eprintln!("[{}:{}] {}", file!(), line!(), $title); + } + } + }; +} diff --git a/src/common/errors.rs b/src/common/errors.rs new file mode 100644 index 0000000..8b7f809 --- /dev/null +++ b/src/common/errors.rs @@ -0,0 +1,91 @@ +use std::{io::Error as IoError, num, result}; + +use thiserror::Error as ThisError; + +use isahc::{http, Error as IsahcError}; + +use url::ParseError as UrlParseError; + +use openssl; + +pub type Result = result::Result; + +#[derive(Debug, ThisError)] +pub enum Error { + // For Arguments + #[error("Output path is invalid: {0}")] + InvalidPath(String), + #[error("Uri is invalid: {0}")] + InvaildUri(#[from] http::uri::InvalidUri), + #[error("Header is invalid: {0}")] + InvalidHeader(String), + #[error("No filename.")] + NoFilename, + #[error("Directory is not found")] + NotFoundDirectory, + #[error("The file already exists.")] + FileExists, + #[error("The path is a directory.")] + PathIsDirectory, + #[error("Can't parse string as number: {0}")] + IsNotNumber(#[from] num::ParseIntError), + #[error("Io Error: {0}")] + Io(#[from] IoError), + #[error("{0} task is not supported")] + UnsupportedTask(String), + + // For Network + #[error("Network error: {0}")] + NetError(String), + #[error("Uncompleted Read")] + UncompletedRead, + #[error("{0} is unsupported")] + UnsupportedMethod(String), + #[error("header is invalid: {0}")] + HeaderParseError(String), + #[error("header is invalid: {0}")] + UrlParseError(#[from] UrlParseError), + #[error("BUG: {0}")] + Bug(String), + #[error("The two content lengths are not equal between the response and the aget file.")] + ContentLengthIsNotConsistent, + + // For m3u8 + #[error("Fail to parse m3u8 file.")] + M3U8ParseFail, + #[error("The two m3u8 parts are not equal between the response and the aget file.")] + PartsAreNotConsistent, + + #[error("An internal error: {0}")] + InnerError(String), + #[error("Content does not has length")] + NoContentLength, + #[error("header is invalid: {0}")] + InvaildHeader(String), + #[error("response status code is: {0}")] + Unsuccess(u16), + #[error("Redirect to: {0}")] + Redirect(String), + #[error("No Location for redirection: {0}")] + NoLocation(String), + #[error("Fail to decrypt aes128 data: {0}")] + AES128DecryptFail(#[from] openssl::error::ErrorStack), +} + +impl From for Error { + fn from(err: IsahcError) -> Error { + Error::NetError(format!("{}", err)) + } +} + +impl From for Error { + fn from(err: http::header::ToStrError) -> Error { + Error::NetError(format!("{}", err)) + } +} + +impl From for Error { + fn from(err: http::Error) -> Error { + Error::NetError(format!("{}", err)) + } +} diff --git a/src/common/file.rs b/src/common/file.rs new file mode 100644 index 0000000..1b541fe --- /dev/null +++ b/src/common/file.rs @@ -0,0 +1,107 @@ +use std::{ + fs::{create_dir_all, metadata, remove_file, File as StdFile, OpenOptions}, + io::{Read, Seek, SeekFrom, Write}, + path::{Path, PathBuf}, +}; + +use crate::common::errors::{Error, Result}; + +/// File can be readed or writen only by opened. +pub struct File { + path: PathBuf, + file: Option, + readable: bool, +} + +impl File { + pub fn new>(path: P, readable: bool) -> Result { + let path = path.as_ref().to_path_buf(); + if path.is_dir() { + return Err(Error::InvalidPath(format!("{:?}", path))); + } + + Ok(File { + path, + file: None, + readable, + }) + } + + /// Create the dir if it does not exists + fn create_dir>(&self, dir: P) -> Result<()> { + if !dir.as_ref().exists() { + Ok(create_dir_all(dir)?) + } else { + Ok(()) + } + } + + /// Create or open the file + pub fn open(&mut self) -> Result<&mut Self> { + if let Some(dir) = self.path.parent() { + self.create_dir(dir)?; + } + let file = OpenOptions::new() + .read(self.readable) + .write(true) + .truncate(false) + .create(true) + .open(self.path.as_path())?; + self.file = Some(file); + Ok(self) + } + + pub fn exists(&self) -> bool { + self.path.as_path().exists() + } + + pub fn file_name(&self) -> Option<&str> { + if let Some(n) = self.path.as_path().file_name() { + n.to_str() + } else { + None + } + } + + pub fn file(&mut self) -> Result<&mut StdFile> { + if let Some(ref mut file) = self.file { + Ok(file) + } else { + Err(Error::Bug("`store::File::file` must be opened".to_string())) + } + } + + pub fn size(&self) -> u64 { + if let Ok(md) = metadata(&self.path) { + md.len() + } else { + 0 + } + } + + pub fn write(&mut self, buf: &[u8], seek: Option) -> Result { + if let Some(seek) = seek { + self.seek(seek)?; + } + Ok(self.file()?.write(buf)?) + } + + pub fn read(&mut self, buf: &mut [u8], seek: Option) -> Result { + if let Some(seek) = seek { + self.seek(seek)?; + } + Ok(self.file()?.read(buf)?) + } + + pub fn seek(&mut self, seek: SeekFrom) -> Result { + Ok(self.file()?.seek(seek)?) + } + + pub fn set_len(&mut self, size: u64) -> Result<()> { + Ok(self.file()?.set_len(size)?) + } + + pub fn remove(&self) -> Result<()> { + Ok(remove_file(self.path.as_path())?) + } +} diff --git a/src/common/liberal.rs b/src/common/liberal.rs new file mode 100644 index 0000000..5121a10 --- /dev/null +++ b/src/common/liberal.rs @@ -0,0 +1,54 @@ +use crate::common::errors::{Error, Result}; + +const SIZES: [&'static str; 5] = ["B", "K", "M", "G", "T"]; + +/// Convert liberal number to u64 +/// e.g. +/// 100k -> 100 * 1024 +pub trait ParseLiteralNumber { + fn literal_number(&self) -> Result; +} + +impl ParseLiteralNumber for &str { + fn literal_number(&self) -> Result { + let (num, unit) = self.split_at(self.len() - 1); + if unit.parse::().is_err() { + let mut num = num.parse::()?; + for s in &SIZES { + if s == &unit.to_uppercase() { + return Ok(num); + } else { + num *= 1024; + } + } + Ok(num) + } else { + let num = self.parse::()?; + Ok(num) + } + } +} + +/// Convert seconds to date format +pub trait ToDate { + fn date(&self) -> String; +} + +impl ToDate for u64 { + fn date(&self) -> String { + let mut num = *self as f64; + if num < 60.0 { + return format!("{:.0}s", num); + } + num /= 60.0; + if num < 60.0 { + return format!("{:.0}m", num); + } + num /= 60.0; + if num < 24.0 { + return format!("{:.0}h", num); + } + num /= 24.0; + return format!("{:.0}d", num); + } +} diff --git a/src/common/list.rs b/src/common/list.rs new file mode 100644 index 0000000..430a845 --- /dev/null +++ b/src/common/list.rs @@ -0,0 +1,30 @@ +use std::sync::{Arc, Mutex}; + +use crate::features::stack::StackLike; + +#[derive(Debug, Clone)] +pub struct SharedVec { + inner: Arc>>, +} + +impl SharedVec { + pub fn new(list: Vec) -> SharedVec { + SharedVec { + inner: Arc::new(Mutex::new(list)), + } + } +} + +impl StackLike for SharedVec { + fn push(&mut self, item: T) { + self.inner.lock().unwrap().push(item) + } + + fn pop(&mut self) -> Option { + self.inner.lock().unwrap().pop() + } + + fn len(&self) -> usize { + self.inner.lock().unwrap().len() + } +} diff --git a/src/common/mod.rs b/src/common/mod.rs new file mode 100644 index 0000000..19e174b --- /dev/null +++ b/src/common/mod.rs @@ -0,0 +1,18 @@ +#[macro_use] +pub mod debug; + +pub mod buf; +pub mod bytes; +pub mod character; +pub mod colors; +pub mod crypto; +pub mod errors; +pub mod file; +pub mod liberal; +pub mod list; +pub mod net; +pub mod range; +pub mod size; +pub mod tasks; +pub mod terminal; +pub mod uri; diff --git a/src/common/net/mod.rs b/src/common/net/mod.rs new file mode 100644 index 0000000..2567161 --- /dev/null +++ b/src/common/net/mod.rs @@ -0,0 +1,2 @@ +pub mod net; +pub mod net_type; diff --git a/src/common/net/net.rs b/src/common/net/net.rs new file mode 100644 index 0000000..572b9c5 --- /dev/null +++ b/src/common/net/net.rs @@ -0,0 +1,184 @@ +use std::time; + +use crate::common::{ + bytes::bytes_type::Bytes, + errors::{Error, Result}, + net::net_type::{ + header, Body, Configurable, ContentLengthValue, HttpClient, Method, Request, Response, Uri, + }, + range::RangePair, +}; + +pub fn parse_header(raw: &str) -> Result<(&str, &str), Error> { + if let Some(index) = raw.find(": ") { + return Ok((&raw[..index], &raw[index + 2..])); + } + if let Some(index) = raw.find(":") { + return Ok((&raw[..index], &raw[index + 1..])); + } + Err(Error::InvalidHeader(raw.to_string())) +} + +pub fn parse_headers<'a, I: IntoIterator>( + raws: I, +) -> Result, Error> { + let mut headers = vec![]; + for raw in raws { + let pair = parse_header(raw)?; + headers.push(pair); + } + Ok(headers) +} + +/// Builder a http client of curl +pub fn build_http_client( + headers: &[(&str, &str)], + timeout: u64, + proxy: Option<&str>, +) -> Result { + let mut builder = HttpClient::builder().default_headers(headers).proxy({ + if proxy.is_some() { + Some(proxy.unwrap().parse()?) + } else { + None + } + }); + // If timeout is zero, no timeout will be enforced. + if timeout > 0 { + builder = builder.timeout(time::Duration::from_secs(timeout)); + } + let client = builder.cookies().build()?; + Ok(client) +} + +// TODO +struct RequestInfo { + method: Method, + uri: String, + headers: Vec<(String, String)>, + // Post data + data: Option, + timeout: u64, + proxy: Option, +} + +impl RequestInfo {} + +/// Check whether the response is success +pub fn is_success(resp: &Response) -> Result<(), Error> { + let status = resp.status(); + if !status.is_success() { + Err(Error::Unsuccess(status.as_u16())) + } else { + Ok(()) + } +} + +/// Send a request with a range header, returning the final uri +pub async fn redirect( + client: &HttpClient, + method: Method, + uri: Uri, + data: Option, +) -> Result { + let mut uri = uri; + loop { + let data = data.clone().map(|d| Body::from_bytes(&d)); + let request = Request::builder() + .method(method.clone()) + .uri(uri.clone()) + .header(header::RANGE, "bytes=0-1") + .body(data)?; + let resp = client.send_async(request).await?; + // is_success(&resp)?; + if !resp.status().is_redirection() { + break; + } + let headers = resp.headers(); + if let Some(location) = headers.get(header::LOCATION) { + uri = location.to_str()?.parse()?; + } else { + break; + } + } + Ok(uri) +} + +/// Get the content length of the resource +pub async fn content_length( + client: &HttpClient, + method: Method, + uri: Uri, + data: Option, +) -> Result { + let mut uri = uri; + loop { + let data = data.clone().map(|d| Body::from_bytes(&d)); + let request = Request::builder() + .method(method.clone()) + .uri(uri.clone()) + .header(header::RANGE, "bytes=0-1") + .body(data)?; + let resp = client.send_async(request).await?; + is_success(&resp)?; + let headers = resp.headers(); + if resp.status().is_redirection() { + if let Some(location) = headers.get(header::LOCATION) { + uri = location.to_str()?.parse()?; + continue; + } else { + return Err(Error::NoLocation(format!("{}", uri))); + } + } else { + if let Some(h) = headers.get(header::CONTENT_RANGE) { + if let Ok(s) = h.to_str() { + if let Some(index) = s.find("/") { + if let Ok(length) = &s[index + 1..].parse::() { + return Ok(ContentLengthValue::RangeLength(length.clone())); + } + } + } + } + if let Some(h) = resp.headers().get(header::CONTENT_LENGTH) { + if let Ok(s) = h.to_str() { + if let Ok(length) = s.parse::() { + return Ok(ContentLengthValue::DirectLength(length.clone())); + } + } + } + break; + } + } + Ok(ContentLengthValue::NoLength) +} + +/// Send a request +pub async fn request( + client: &HttpClient, + method: Method, + uri: Uri, + data: Option, + range: Option, +) -> Result> { + let mut uri = uri; + loop { + let data = data.clone().map(|d| Body::from_bytes(&d)); + let mut builder = Request::builder().method(method.clone()).uri(uri.clone()); + + if let Some(RangePair { begin, end }) = range { + builder = builder.header(header::RANGE, &format!("bytes={}-{}", begin, end)); + } + let request = builder.body(data)?; + let resp = client.send_async(request).await?; + is_success(&resp)?; + if resp.status().is_redirection() { + if let Some(location) = resp.headers().get(header::LOCATION) { + uri = location.to_str()?.parse()?; + continue; + } else { + return Err(Error::NoLocation(format!("{}", uri))); + } + } + return Ok(resp); + } +} diff --git a/src/common/net/net_type.rs b/src/common/net/net_type.rs new file mode 100644 index 0000000..932cfd5 --- /dev/null +++ b/src/common/net/net_type.rs @@ -0,0 +1,16 @@ +pub use isahc::{ + self, + config::Configurable, + http, + http::{header, HeaderMap, HeaderValue, Method, Request, Response, Uri}, + Body, HttpClient, RequestExt, ResponseExt, +}; + +pub use url::Url; + +#[derive(Debug)] +pub enum ContentLengthValue { + RangeLength(u64), + DirectLength(u64), + NoLength, +} diff --git a/src/common/range.rs b/src/common/range.rs new file mode 100644 index 0000000..99970c5 --- /dev/null +++ b/src/common/range.rs @@ -0,0 +1,70 @@ +use std::sync::{Arc, Mutex}; + +use crate::features::stack::StackLike; + +#[derive(Debug, Clone, Copy)] +pub struct RangePair { + pub begin: u64, + pub end: u64, +} + +impl RangePair { + pub fn new(begin: u64, end: u64) -> RangePair { + RangePair { begin, end } + } + + // The length of a `RangePair` is the closed interval length + pub fn length(&self) -> u64 { + self.end - self.begin + 1 + } +} + +pub type RangeList = Vec; + +#[derive(Debug, Clone)] +pub struct SharedRangList { + inner: Arc>, +} + +impl SharedRangList { + pub fn new(rangelist: RangeList) -> SharedRangList { + SharedRangList { + inner: Arc::new(Mutex::new(rangelist)), + } + } +} + +impl StackLike for SharedRangList { + fn push(&mut self, pair: RangePair) { + self.inner.lock().unwrap().push(pair) + } + + fn pop(&mut self) -> Option { + self.inner.lock().unwrap().pop() + } + + fn len(&self) -> usize { + self.inner.lock().unwrap().len() + } +} + +/// Split a close `RangePair` to many piece of pairs that each of their size is equal to +/// `chunk_size`, but the last piece size can be less then `chunk_size`. +pub fn split_pair(pair: &RangePair, chunk_size: u64) -> RangeList { + let mut stack = Vec::new(); + + let mut begin = pair.begin; + let interval_end = pair.end; + + while begin + chunk_size - 1 <= interval_end { + let end = begin + chunk_size - 1; + stack.push(RangePair::new(begin, end)); + begin += chunk_size; + } + + if begin <= interval_end { + stack.push(RangePair::new(begin, interval_end)); + } + + stack +} diff --git a/src/common/size.rs b/src/common/size.rs new file mode 100644 index 0000000..a427286 --- /dev/null +++ b/src/common/size.rs @@ -0,0 +1,34 @@ +/// Here, we handle about size. + +const SIZES: [&'static str; 5] = ["B", "K", "M", "G", "T"]; + +/// Convert number to human-readable +pub trait HumanReadable { + fn human_readable(&self) -> String; +} + +impl HumanReadable for u64 { + fn human_readable(&self) -> String { + let mut num = *self as f64; + for s in &SIZES { + if num < 1024.0 { + return format!("{:.1}{}", num, s); + } + num /= 1024.0; + } + return format!("{:.1}{}", num, SIZES[SIZES.len() - 1]); + } +} + +impl HumanReadable for f64 { + fn human_readable(&self) -> String { + let mut num = *self; + for s in &SIZES { + if num < 1024.0 { + return format!("{:.1}{}", num, s); + } + num /= 1024.0; + } + return format!("{:.1}{}", num, SIZES[SIZES.len() - 1]); + } +} diff --git a/src/common/tasks.rs b/src/common/tasks.rs new file mode 100644 index 0000000..33924d7 --- /dev/null +++ b/src/common/tasks.rs @@ -0,0 +1,5 @@ +#[derive(Debug, Clone)] +pub enum TaskType { + HTTP, + M3U8, +} diff --git a/src/common/terminal.rs b/src/common/terminal.rs new file mode 100644 index 0000000..753b174 --- /dev/null +++ b/src/common/terminal.rs @@ -0,0 +1,13 @@ +use term_size::dimensions; + +const MIN_TERMINAL_WIDTH: u64 = 60; + +pub fn terminal_width() -> u64 { + if let Some((width, _)) = dimensions() { + width as u64 + } else { + // for envrionment in which atty is not available, + // example, at ci of osx + MIN_TERMINAL_WIDTH + } +} diff --git a/src/common/uri.rs b/src/common/uri.rs new file mode 100644 index 0000000..28f5c4d --- /dev/null +++ b/src/common/uri.rs @@ -0,0 +1,22 @@ +use std::path::Path; + +use crate::common::{ + errors::{Error, Result}, + net::net_type::Uri, +}; + +/// Use the last of components of uri as a file name +pub trait UriFileName { + fn file_name(&self) -> Result<&str>; +} + +impl UriFileName for Uri { + fn file_name(&self) -> Result<&str> { + let path = Path::new(self.path()); + if let Some(file_name) = path.file_name() { + Ok(file_name.to_str().unwrap()) + } else { + Err(Error::NoFilename) + } + } +} diff --git a/src/core.rs b/src/core.rs deleted file mode 100644 index 0d8919a..0000000 --- a/src/core.rs +++ /dev/null @@ -1,314 +0,0 @@ -use std::{cell::RefCell, cmp::min, io::SeekFrom, rc::Rc, time::Duration}; - -use futures::{ - channel::mpsc::{channel, Receiver}, - select, - stream::StreamExt, -}; - -use actix_rt::{spawn, time::interval, System}; -use bytes::Bytes; - -use crate::{ - app::Config, - chunk::{make_range_chunks, RangePart, RangeStack}, - error::{AgetError, NetError, Result}, - printer::Printer, - request::{get_content_length, get_redirect_uri, AgetRequestOptions, ContentLengthItem}, - store::{AgetFile, File, TaskInfo}, - task::RequestTask, - util::QUIET, -}; - -pub struct CoreProcess { - config: Config, - options: AgetRequestOptions, - range_stack: RangeStack, - // the length of range_stack - range_count: u64, -} - -impl CoreProcess { - pub fn new(config: Config) -> Result { - let headers = config - .headers - .iter() - .map(AsRef::as_ref) - .collect::>(); - let data = config.data.as_ref().map(AsRef::as_ref); - let options = - AgetRequestOptions::new(&config.uri, &config.method, &headers, data, config.timeout)?; - - Ok(CoreProcess { - config, - options, - range_stack: Rc::new(RefCell::new(vec![])), - range_count: 1, - }) - } - - fn check_content_length(&self, content_length: u64) -> Result<()> { - debug!("Check content length", content_length); - let mut aget_file = AgetFile::new(&self.config.path)?; - if aget_file.exists() { - aget_file.open()?; - if content_length != aget_file.content_length()? { - debug!( - "!! the content length that response returned isn't equal of aget file", - format!("{} != {}", content_length, aget_file.content_length()?) - ); - return Err(AgetError::ContentLengthIsNotConsistent.into()); - } - } - debug!("Check content length: equal"); - - Ok(()) - } - - fn set_content_length(&self, content_length: u64) -> Result<()> { - debug!("Set content length", content_length); - let mut aget_file = AgetFile::new(&self.config.path)?; - if !aget_file.exists() { - aget_file.open()?; - aget_file.write_content_length(content_length)?; - } else { - aget_file.open()?; - aget_file.rewrite()?; - } - Ok(()) - } - - fn make_range_stack(&mut self) -> Result<()> { - debug!("Make range stack"); - - let mut range_stack: Vec = Vec::new(); - - if self.options.is_concurrent() { - let mut aget_file = AgetFile::new(&self.config.path)?; - aget_file.open()?; - let gaps = aget_file.gaps()?; - - let chunk_length = self.config.chunk_length; - for gap in gaps.iter() { - let mut list = make_range_chunks(gap, chunk_length); - range_stack.append(&mut list); - } - range_stack.reverse(); - } else { - range_stack.push(RangePart::new(0, 0)); - }; - - self.range_count = range_stack.len() as u64; - - debug!("Range stack size", range_stack.len()); - self.range_stack = Rc::new(RefCell::new(range_stack)); - - Ok(()) - } - - pub async fn run(&mut self) -> Result<()> { - // 1. Get redirected uri - debug!("Redirect task"); - get_redirect_uri(&mut self.options).await?; - - // 2. Get content length - debug!("ContentLength task"); - let cn_item = get_content_length(&mut self.options).await?; - match cn_item { - ContentLengthItem::RangeLength(content_length) => { - self.check_content_length(content_length)?; - self.set_content_length(content_length)?; - } - ContentLengthItem::DirectLength(content_length) => { - self.set_content_length(content_length)?; - self.options.no_concurrency(); - - // Let connector to be always alive - self.options.reset_connector(60, 60, 0); - } - ContentLengthItem::NoLength => { - return Err(NetError::NoContentLength.into()); - } - } - self.make_range_stack()?; - - // 3. Spawn concurrent tasks - let is_concurrent = self.options.is_concurrent(); - let (sender, receiver) = - channel::<(RangePart, Bytes)>((self.config.concurrency + 1) as usize); - let concurrency = if is_concurrent { - self.config.concurrency - } else { - 1 - }; - debug!("Spawn RequestTasks", concurrency); - - for i in 0..(min(self.config.concurrency, self.range_count)) { - debug!("RequestTask ", i); - let range_stack = self.range_stack.clone(); - let sender_ = sender.clone(); - let options = self.options.clone(); - let task = async { - let is_concurrent = options.is_concurrent(); - let mut request_task = RequestTask::new(range_stack, sender_); - let result = request_task.run(options).await; - if let Err(err) = result { - print_err!("RequestTask fails", err); - if !is_concurrent { - // Exit process when the only one request task fails - System::current().stop(); - } - } - }; - spawn(task); - } - - // 4. Wait stream handler - debug!("Start StreamHander"); - let concurrency = if is_concurrent { - self.config.concurrency - } else { - 0 - }; - let stream_header = StreamHander::new(&self.config.path, !is_concurrent, concurrency); - if stream_header.is_err() { - System::current().stop(); - } - let stream_header = stream_header.unwrap(); - stream_header.run(receiver).await; - - debug!("CoreProcess done"); - Ok(()) - } -} - -struct StreamHander { - file: File, - aget_file: AgetFile, - task_info: TaskInfo, - printer: Printer, - no_record: bool, - concurrency: u64, -} - -impl StreamHander { - fn new(path: &str, no_record: bool, concurrency: u64) -> Result { - let task_info = TaskInfo::new(path)?; - - let mut file = File::new(path, false)?; - file.open()?; - let mut aget_file = AgetFile::new(path)?; - aget_file.open()?; - - let printer = Printer::new(); - let mut handler = StreamHander { - file, - aget_file, - task_info, - printer, - no_record, - concurrency, - }; - handler.init_print()?; - Ok(handler) - } - - fn init_print(&mut self) -> Result<(), AgetError> { - unsafe { - if QUIET { - return Ok(()); - } - } - - if self.no_record { - self.printer - .print_msg("Server doesn't support range request.")?; - } - - let file_name = &self.task_info.path; - let content_length = self.task_info.content_length; - let concurrency = self.concurrency; - self.printer.print_header(file_name)?; - self.printer.print_length(content_length)?; - self.printer.print_concurrency(concurrency)?; - self.print_process()?; - Ok(()) - } - - fn print_process(&mut self) -> Result<(), AgetError> { - unsafe { - if QUIET { - return Ok(()); - } - } - - let total_length = self.task_info.content_length; - let completed_length = self.task_info.completed_length(); - let (rate, eta) = self.task_info.rate_and_eta(); - self.printer - .print_process(completed_length, total_length, rate, eta)?; - Ok(()) - } - - fn record_range(&mut self, range_part: RangePart) -> Result<(), ()> { - if self.no_record { - return Ok(()); - } - if let Err(err) = self.aget_file.write_interval(range_part) { - print_err!("write interval to aget file fails", err); - return Err(()); - } - return Ok(()); - } - - fn teardown(&mut self) -> Result<(), AgetError> { - self.aget_file.remove()?; - Ok(()) - } - - pub async fn run(mut self, receiver: Receiver<(RangePart, Bytes)>) { - debug!("StreamHander run"); - let mut tick = interval(Duration::from_secs(1)).fuse(); - let mut receiver = receiver.fuse(); - loop { - select! { - item = receiver.next() => { - if let Some((range, chunk)) = item { - let interval_length = range.length(); - - // write buf - if let Err(err) = self - .file - .write(&chunk[..], Some(SeekFrom::Start(range.start))) - { - print_err!("write chunk to file fails", err); - } - - // write range_part - self.record_range(range); - - // update `task_info` - self.task_info.add_completed(interval_length); - } else { - break; - } - } - _ = tick.next() => { - if let Err(err) = self.print_process() { - print_err!("print process fails", err); - } - self.task_info.clean_interval(); - if self.task_info.remains() == 0 { - if let Err(err) = self.print_process() { - print_err!("print process fails", err); - } - if let Err(err) = self.teardown() { - print_err!("teardown stream handler fails", err); - } - break; - } - } - } - } - } -} diff --git a/src/error.rs b/src/error.rs deleted file mode 100644 index 072fc4a..0000000 --- a/src/error.rs +++ /dev/null @@ -1,182 +0,0 @@ -use std::{fmt, io::Error as IoError, num, result}; - -use failure::{self, Backtrace, Fail}; -use futures::channel::mpsc::SendError; - -use awc::{ - self, - error::SendRequestError, - http, - http::{ - header::{InvalidHeaderName, InvalidHeaderValue, ToStrError}, - uri::InvalidUri, - }, -}; - -pub type Result = result::Result; - -pub struct Error { - cause: Box, - backtrace: Option, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - fmt::Display::fmt(&self.cause, f) - } -} - -impl fmt::Debug for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if let Some(bt) = self.cause.backtrace() { - write!(f, "{:?}\n\n{:?}", &self.cause, bt) - } else { - write!( - f, - "{:?}\n\n{:?}", - &self.cause, - self.backtrace.as_ref().unwrap() - ) - } - } -} - -pub trait AgetFail: Fail {} - -impl From for Error { - fn from(err: T) -> Error { - let backtrace = if err.backtrace().is_none() { - Some(Backtrace::new()) - } else { - None - }; - Error { - cause: Box::new(err), - backtrace, - } - } -} - -#[derive(Fail, Debug)] -pub enum ArgError { - #[fail(display = "Output path is invalid: {}", _0)] - InvalidPath(String), - #[fail(display = "the uri is invalid: {}", _0)] - InvaildUri(String), - #[fail(display = "No filename.")] - NoFilename, - #[fail(display = "Directory is not found")] - NotFoundDirectory, - #[fail(display = "The file already exists.")] - FileExists, - #[fail(display = "The path is a directory.")] - PathIsDirectory, - #[fail(display = "Can't parse string as number: {}", _0)] - IsNotNumber(String), - #[fail(display = "Io Error: {}", _0)] - Io(#[cause] IoError), -} - -impl AgetFail for ArgError {} - -impl From for ArgError { - fn from(err: http::uri::InvalidUri) -> ArgError { - ArgError::InvaildUri(format!("{}", err)) - } -} - -impl From for ArgError { - fn from(err: num::ParseIntError) -> ArgError { - ArgError::IsNotNumber(format!("{}", err)) - } -} - -impl From for ArgError { - fn from(err: IoError) -> ArgError { - ArgError::Io(err) - } -} - -#[derive(Fail, Debug)] -pub enum AgetError { - #[fail(display = "Method is unsupported")] - UnsupportedMethod, - #[fail(display = "header is invalid: {}", _0)] - HeaderParseError(String), - #[fail(display = "BUG: {}", _0)] - Bug(String), - #[fail(display = "Path is invalid: {}", _0)] - InvalidPath(String), - #[fail(display = "Io Error: {}", _0)] - Io(#[cause] IoError), - #[fail(display = "No filename.")] - NoFilename, - #[fail(display = "The file already exists.")] - FileExists, - #[fail( - display = "The two content lengths are not equal between the response and the aget file." - )] - ContentLengthIsNotConsistent, -} - -impl AgetFail for AgetError {} - -impl From for AgetError { - fn from(err: IoError) -> AgetError { - AgetError::Io(err) - } -} - -#[derive(Fail, Debug)] -pub enum NetError { - #[fail(display = "an internal error: {}", _0)] - ActixError(String), - #[fail(display = "content does not has length")] - NoContentLength, - #[fail(display = "uri is invalid: {}", _0)] - InvaildUri(String), - #[fail(display = "header is invalid: {}", _0)] - InvaildHeader(String), - #[fail(display = "response status code is: {}", _0)] - Unsuccess(u16), - #[fail(display = "Redirect to: {}", _0)] - Redirect(String), -} - -impl AgetFail for NetError {} - -impl From for NetError { - fn from(err: SendRequestError) -> NetError { - NetError::ActixError(format!("{}", err)) - } -} - -impl From for NetError { - fn from(err: ToStrError) -> NetError { - NetError::ActixError(format!("{}", err)) - } -} - -impl From for NetError { - fn from(err: InvalidUri) -> NetError { - NetError::InvaildUri(format!("{}", err)) - } -} - -impl From for NetError { - fn from(err: InvalidHeaderName) -> NetError { - NetError::InvaildHeader(format!("{}", err)) - } -} - -impl From for NetError { - fn from(err: InvalidHeaderValue) -> NetError { - NetError::InvaildHeader(format!("{}", err)) - } -} - -impl From for NetError { - fn from(err: SendError) -> NetError { - NetError::ActixError(format!("{}", err)) - } -} diff --git a/src/features/args.rs b/src/features/args.rs new file mode 100644 index 0000000..9d1f689 --- /dev/null +++ b/src/features/args.rs @@ -0,0 +1,52 @@ +use std::path::PathBuf; + +use crate::common::{ + bytes::bytes_type::BytesMut, + net::net_type::{Method, Uri}, + tasks::TaskType, +}; + +/// This a arg which gives parameters for apps +pub trait Args { + /// Path of output + fn output(&self) -> PathBuf; + + /// Request method for http + fn method(&self) -> Method; + + /// The uri of a task + fn uri(&self) -> Uri; + + /// The data for http post request + fn data(&self) -> Option; + + /// Request headers + fn headers(&self) -> Vec<(String, String)>; + + /// Proxy: http, https, socks4, socks5 + fn proxy(&self) -> Option; + + /// The maximum time the request is allowed to take. + fn timeout(&self) -> u64; + + /// The number of concurrency + fn concurrency(&self) -> u64; + + /// The chunk length of each concurrency for http task + fn chunk_length(&self) -> u64; + + /// The number of retry of a task + fn retries(&self) -> u64; + + /// The internal of each retry + fn retry_wait(&self) -> u64; + + /// Task type + fn task_type(&self) -> TaskType; + + /// To debug mode, if it returns true + fn debug(&self) -> bool; + + /// To quiet mode, if it return true + fn quiet(&self) -> bool; +} diff --git a/src/features/mod.rs b/src/features/mod.rs new file mode 100644 index 0000000..2bed12f --- /dev/null +++ b/src/features/mod.rs @@ -0,0 +1,3 @@ +pub mod args; +pub mod running; +pub mod stack; diff --git a/src/features/running.rs b/src/features/running.rs new file mode 100644 index 0000000..f880277 --- /dev/null +++ b/src/features/running.rs @@ -0,0 +1,5 @@ +use crate::common::errors::Result; + +pub trait Runnable { + fn run(&mut self) -> Result<()>; +} diff --git a/src/features/stack.rs b/src/features/stack.rs new file mode 100644 index 0000000..c7acca7 --- /dev/null +++ b/src/features/stack.rs @@ -0,0 +1,7 @@ +pub trait StackLike { + fn push(&mut self, item: T); + + fn pop(&mut self) -> Option; + + fn len(&self) -> usize; +} diff --git a/src/main.rs b/src/main.rs index 3847294..8612fed 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,93 +1,68 @@ -#![allow(unused_variables)] #![allow(dead_code)] -#![recursion_limit = "256"] - -use std::{process::exit, thread, time}; - -use actix_rt::{spawn, System}; #[macro_use] -mod util; +mod common; mod app; -mod chunk; -mod clap_app; -mod common; -mod core; -mod error; -mod printer; -mod request; -mod store; -mod task; +mod arguments; +mod features; -use crate::{ - app::App, - core::CoreProcess, - util::{DEBUG, QUIET}, -}; +use std::{thread, time::Duration}; -static mut SUCCESS: bool = false; +use app::core::{http::HttpHandler, m3u8::m3u8::M3u8Handler}; +use arguments::cmd_args::CmdArgs; +use common::{ + debug::{DEBUG, QUIET}, + tasks::TaskType, +}; +use features::{args::Args, running::Runnable}; fn main() { - let app = App::new(); - match app.config() { - Ok(config) => { - // set verbose - unsafe { - DEBUG = config.debug; - QUIET = config.quiet; - } - - debug!("Input configuration", &config); + let cmdargs = CmdArgs::new(); - let retry_wait = config.retry_wait; - let max_retries = config.max_retries; - for i in 0..(max_retries + 1) { - if i > 0 { - print_err!("!!! Retry", i); - thread::sleep(time::Duration::from_secs(retry_wait)); - } + // Set debug + if cmdargs.debug() { + unsafe { + DEBUG = true; + } + debug!("Args", cmdargs); + } - let sys = System::new("Aget"); + // Set quiet + if cmdargs.quiet() { + unsafe { + QUIET = true; + } + } - debug!("Make CoreProcess task"); + debug!("Main: begin"); - let config_ = config.clone(); - spawn(async move { - let core_process = CoreProcess::new(config_); - if let Ok(mut core_fut) = core_process { - let result = core_fut.run().await; - if let Err(err) = result { - print_err!("core_fut fails", err); - } else { - unsafe { - SUCCESS = true; - } - } - } else { - print_err!("core_fut error", ""); - exit(1); - } - debug!("main done"); - System::current().stop(); - }); + let tasktype = cmdargs.task_type(); + for i in 0..cmdargs.retries() + 1 { + if i != 0 { + println!("Retry {}", i); + } - if let Err(err) = sys.run() { - print_err!("System error", err); - } else { - // check task state - unsafe { - if SUCCESS { - break; - } - } - } + let result = match tasktype { + TaskType::HTTP => { + let mut httphandler = HttpHandler::new(&cmdargs).unwrap(); + httphandler.run() } - // debug!("!!! Can't be here"); - } - Err(err) => { - print_err!("app config fails", err); - exit(1); + TaskType::M3U8 => { + let mut m3u8handler = M3u8Handler::new(&cmdargs).unwrap(); + m3u8handler.run() + } + }; + + if let Err(err) = result { + print_err!("Error", err); + // Retry + let retrywait = cmdargs.retry_wait(); + thread::sleep(Duration::from_secs(retrywait)); + continue; + } else { + // Success + break; } } } diff --git a/src/printer.rs b/src/printer.rs deleted file mode 100644 index 29cf39f..0000000 --- a/src/printer.rs +++ /dev/null @@ -1,156 +0,0 @@ -use std::io::{stdout, Stdout, Write}; - -use ansi_term::{ - Colour::{Blue, Cyan, Green, Red, Yellow}, - Style, -}; - -use crate::{ - error::{AgetError, Result}, - util::{terminal_width, SizeOfFmt, TimeOfFmt}, -}; - -pub struct Printer { - colors: Colors, - terminal_width: u64, - stdout: Stdout, -} - -impl Printer { - pub fn new() -> Printer { - let terminal_width = terminal_width(); - Printer { - colors: Colors::colored(), - terminal_width, - stdout: stdout(), - } - } - - pub fn print_msg(&mut self, msg: &str) -> Result<(), AgetError> { - writeln!(&mut self.stdout, "\n {}", self.colors.msg.paint(msg))?; - Ok(()) - } - - pub fn print_header(&mut self, path: &str) -> Result<(), AgetError> { - writeln!( - &mut self.stdout, - "\n {}: {}", - self.colors.file_header.paint(" File"), - path, - )?; - Ok(()) - } - - pub fn print_length(&mut self, content_length: u64) -> Result<(), AgetError> { - writeln!( - &mut self.stdout, - " {}: {} ({})", - self.colors.content_length_header.paint("Length"), - content_length.sizeof_fmt(), - content_length, - )?; - Ok(()) - } - - pub fn print_concurrency(&mut self, concurrency: u64) -> Result<(), AgetError> { - writeln!( - &mut self.stdout, - "{}: {}\n", - self.colors.concurrency_header.paint("concurrency"), - concurrency, - )?; - Ok(()) - } - - pub fn print_process( - &mut self, - completed_length: u64, - total_length: u64, - rate: f64, - eta: u64, - ) -> Result<(), AgetError> { - let percent = completed_length as f64 / total_length as f64; - - let completed_length_str = completed_length.sizeof_fmt(); - let total_length_str = total_length.sizeof_fmt(); - let percent_str = format!("{:.2}", percent * 100.0); - let rate_str = rate.sizeof_fmt(); - let eta_str = eta.timeof_fmt(); - - // maximum info length is 41 e.g. - // 1001.3k/1021.9m 97.98% 1003.1B/s eta: 12s - let info = format!( - "{completed_length}/{total_length} {percent}% {rate}/s eta: {eta}", - completed_length = completed_length_str, - total_length = total_length_str, - percent = percent_str, - rate = rate_str, - eta = eta_str, - ); - - // set default info length - let info_length = 41; - let miss = info_length - info.len(); - - let bar_length = self.terminal_width - info_length as u64 - 4; - let process_bar_length = (bar_length as f64 * percent) as u64; - let blank_length = bar_length - process_bar_length; - - let process_bar_str = if process_bar_length > 0 { - format!("{}>", "=".repeat((process_bar_length - 1) as usize)) - } else { - "".to_owned() - }; - let blank_str = " ".repeat(blank_length as usize); - - write!( - &mut self.stdout, - "\r{completed_length}/{total_length} {percent}% {rate}/s eta: {eta}{miss} [{process_bar}{blank}] ", - completed_length = self.colors.completed_length.paint(completed_length_str), - total_length = self.colors.total_length.paint(total_length_str), - percent = self.colors.percent.paint(percent_str), - rate = self.colors.rate.paint(rate_str), - eta = self.colors.eta.paint(eta_str), - miss = " ".repeat(miss), - process_bar = process_bar_str, - blank = blank_str, - )?; - - self.stdout.flush()?; - - Ok(()) - } -} - -#[derive(Default)] -pub struct Colors { - pub file_header: Style, - pub content_length_header: Style, - pub concurrency_header: Style, - pub completed_length: Style, - pub total_length: Style, - pub percent: Style, - pub rate: Style, - pub eta: Style, - pub msg: Style, -} - -impl Colors { - pub fn plain() -> Colors { - Colors::default() - } - - pub fn colored() -> Colors { - Colors { - file_header: Green.bold(), - content_length_header: Blue.bold(), - completed_length: Red.bold(), - concurrency_header: Yellow.bold(), - total_length: Green.bold(), - percent: Yellow.bold(), - rate: Blue.bold(), - eta: Cyan.bold(), - msg: Yellow.italic(), - } - } -} diff --git a/src/request.rs b/src/request.rs deleted file mode 100644 index e9d70d4..0000000 --- a/src/request.rs +++ /dev/null @@ -1,198 +0,0 @@ -use std::time::Duration; - -use awc::{ - http::{header, Method, Uri}, - Client, ClientBuilder, ClientRequest, Connector, -}; - -use clap::crate_version; - -use crate::error::{AgetError, NetError, Result}; - -fn parse_header(raw: &str) -> Result<(&str, &str), AgetError> { - if let Some(index) = raw.find(": ") { - return Ok((&raw[..index], &raw[index + 2..])); - } - if let Some(index) = raw.find(":") { - return Ok((&raw[..index], &raw[index + 1..])); - } - Err(AgetError::HeaderParseError(raw.to_string())) -} - -#[derive(Clone)] -pub struct AgetRequestOptions { - uri: String, - method: Method, - headers: Vec<(String, String)>, - body: Option, - concurrent: bool, - client: Client, -} - -impl AgetRequestOptions { - pub fn new( - uri: &str, - method: &str, - headers: &[&str], - body: Option<&str>, - timeout: u64, - ) -> Result { - let method = match method.to_uppercase().as_str() { - "GET" => Method::GET, - "POST" => Method::POST, - _ => return Err(AgetError::UnsupportedMethod), - }; - - let mut header_list = Vec::new(); - for header in headers.iter() { - let (key, value) = parse_header(header)?; - header_list.push((key.to_string(), value.to_string())); - } - - let connector = Connector::new() - .limit(0) // no limit simultaneous connections. - .timeout(Duration::from_secs(timeout)) // DNS timeout - .conn_keep_alive(Duration::from_secs(60)) - .conn_lifetime(Duration::from_secs(0)) - .finish(); - let client = ClientBuilder::new().connector(connector).finish(); - - Ok(AgetRequestOptions { - method, - uri: uri.to_string(), - headers: header_list, - body: if let Some(body) = body { - Some(body.to_string()) - } else { - None - }, - concurrent: true, - client, - }) - } - - pub fn build(&self) -> Result { - // set user-agent if none - let aget_ua = format!("aget/{}", crate_version!()); - - let uri = self.uri.parse::()?; - let host = if let Some(host) = uri.host() { - host - } else { - return Err(NetError::InvaildUri(self.uri.to_string())); - }; - - let mut client_request = self - .client - .request(self.method.clone(), self.uri.clone()) - .set_header_if_none("User-Agent", aget_ua) // set default user-agent - .set_header_if_none("Accept", "*/*") // set accept if none - .set_header_if_none("Host", host); // set header `Host` - - for (ref key, ref val) in &self.headers { - client_request - .headers_mut() - .insert(key.as_str().parse()?, val.as_str().parse()?); - } - - Ok(client_request) - } - - pub fn uri(&self) -> String { - self.uri.clone() - } - - pub fn set_uri(&mut self, uri: &str) -> &mut Self { - self.uri = uri.to_string(); - self - } - - pub fn body(&self) -> &Option { - &self.body - } - - pub fn is_concurrent(&self) -> bool { - self.concurrent - } - - pub fn no_concurrency(&mut self) -> &mut Self { - self.concurrent = false; - self - } - - pub fn reset_connector(&mut self, timeout: u64, keep_alive: u64, lifetime: u64) -> &mut Self { - let connector = Connector::new() - .limit(0) // no limit simultaneous connections. - .timeout(Duration::from_secs(timeout)) // DNS timeout - .conn_keep_alive(Duration::from_secs(keep_alive)) - .conn_lifetime(Duration::from_secs(lifetime)) - .finish(); - self.client = ClientBuilder::new().connector(connector).finish(); - self - } -} - -/// Get redirected uri and reset `AgetRequestOptions.uri` -pub async fn get_redirect_uri(options: &mut AgetRequestOptions) -> Result<(), NetError> { - loop { - let client_request = options.build()?; - let resp = if let Some(body) = options.body() { - client_request.send_body(body).await? - } else { - client_request.send().await? - }; - let status = resp.status(); - if !(status.is_success() || status.is_redirection()) { - return Err(NetError::Unsuccess(status.as_u16())); - } else { - if status.is_redirection() { - if let Some(location) = resp.headers().get(header::LOCATION) { - options.set_uri(location.to_str()?); - } - } - return Ok(()); - } - } -} - -#[derive(Debug)] -pub enum ContentLengthItem { - RangeLength(u64), - DirectLength(u64), - NoLength, -} - -pub async fn get_content_length( - options: &mut AgetRequestOptions, -) -> Result { - let client_request = options.build()?.header(header::RANGE, "bytes=0-1"); - let resp = if let Some(body) = options.body() { - client_request.send_body(body).await? - } else { - client_request.send().await? - }; - - let status = resp.status(); - if !status.is_success() { - return Err(NetError::Unsuccess(status.as_u16())); - } else { - if let Some(h) = resp.headers().get(header::CONTENT_RANGE) { - if let Ok(s) = h.to_str() { - if let Some(index) = s.find("/") { - if let Ok(length) = &s[index + 1..].parse::() { - return Ok(ContentLengthItem::RangeLength(length.clone())); - } - } - } - } else { - if let Some(h) = resp.headers().get(header::CONTENT_LENGTH) { - if let Ok(s) = h.to_str() { - if let Ok(length) = s.parse::() { - return Ok(ContentLengthItem::DirectLength(length.clone())); - } - } - } - } - } - Ok(ContentLengthItem::NoLength) -} diff --git a/src/store.rs b/src/store.rs deleted file mode 100644 index e19c24c..0000000 --- a/src/store.rs +++ /dev/null @@ -1,362 +0,0 @@ -use std::{ - cmp::max, - fs::{remove_file, File as StdFile, OpenOptions}, - io::{Read, Seek, SeekFrom, Write}, - path::PathBuf, - time::Instant, -}; - -use crate::{ - chunk::RangePart, - common::AGET_EXT, - error::{AgetError, Result}, -}; - -pub struct TaskInfo { - pub path: String, - - /// The length of the file - pub content_length: u64, - - /// The length stored to the file - completed_length: u64, - - /// The stored length at an interval of one tick - interval_length: u64, - - /// The interval of one tick - tick_interval: Instant, -} - -impl TaskInfo { - pub fn new(path: &str) -> Result { - let mut aget_file = AgetFile::new(path)?; - aget_file.open()?; - - let path = path.to_string(); - let content_length = aget_file.content_length()?; - let completed_length = aget_file.completed_length()?; - - Ok(TaskInfo { - path, - content_length, - completed_length, - interval_length: 0, - tick_interval: Instant::now(), - }) - } - - pub fn completed_length(&self) -> u64 { - self.completed_length - } - - pub fn remains(&self) -> u64 { - self.content_length - self.completed_length - } - - pub fn rate_and_eta(&self) -> (f64, u64) { - let interval = self.tick_interval.elapsed().as_secs_f64(); - let rate = self.interval_length as f64 / interval; - let remains = self.remains(); - // rate > 1.0 for overflow - let eta = if remains > 0 && rate > 1.0 { - let eta = (remains as f64 / rate) as u64; - // eta is large than 99 days, return 0 - if eta > 99 * 24 * 60 * 60 { - 0 - } else { - eta - } - } else { - 0 - }; - (rate, eta) - } - - pub fn add_completed(&mut self, interval_length: u64) { - self.completed_length += interval_length; - self.interval_length += interval_length; - } - - pub fn clean_interval(&mut self) { - self.interval_length = 0; - self.tick_interval = Instant::now(); - } -} - -pub struct File { - path: PathBuf, - file: Option, - readable: bool, -} - -impl File { - pub fn new(p: &str, readable: bool) -> Result { - let path = PathBuf::from(p); - if path.is_dir() { - return Err(AgetError::InvalidPath(p.to_string())); - } - - Ok(File { - path, - file: None, - readable, - }) - } - - pub fn open(&mut self) -> Result<&mut Self, AgetError> { - let file = OpenOptions::new() - .read(self.readable) - .write(true) - .truncate(false) - .create(true) - .open(self.path.as_path())?; - self.file = Some(file); - Ok(self) - } - - pub fn exists(&self) -> bool { - self.path.as_path().exists() - } - - pub fn file_name(&self) -> Result { - if let Some(file_name) = self.path.as_path().file_name() { - Ok(file_name.to_str().unwrap().to_string()) - } else { - Err(AgetError::NoFilename) - } - } - - pub fn file(&mut self) -> Result<&mut StdFile, AgetError> { - if let Some(ref mut file) = self.file { - Ok(file) - } else { - Err(AgetError::Bug( - "`store::File::file` must be opened".to_string(), - )) - } - } - - pub fn write(&mut self, buf: &[u8], seek: Option) -> Result { - if let Some(seek) = seek { - self.seek(seek)?; - } - let s = self.file()?.write(buf)?; - Ok(s) - } - - pub fn read(&mut self, buf: &mut [u8], seek: Option) -> Result { - if let Some(seek) = seek { - self.seek(seek)?; - } - let s = self.file()?.read(buf)?; - Ok(s) - } - - pub fn seek(&mut self, seek: SeekFrom) -> Result { - let s = self.file()?.seek(seek)?; - Ok(s) - } - - pub fn set_len(&mut self, size: u64) -> Result<(), AgetError> { - self.file()?.set_len(size)?; - Ok(()) - } - - pub fn remove(&self) -> Result<(), AgetError> { - remove_file(self.path.as_path())?; - Ok(()) - } -} - -/// Aget Information Store File -/// -/// The file stores two kinds of information which of the downloading file. -/// 1. Content Length -/// The downloading file's content length. -/// If the number `request::ContentLength` returned is not equal to this content length, -/// the process will be terminated. -/// 2. Close Intervals of Downloaded Pieces -/// These intervals are pieces `(u64, u64)` stored as big-endian. -/// First item is the begin of header `Range`. -/// Second item is the end of header `Range`. -pub struct AgetFile { - inner: File, -} - -impl AgetFile { - pub fn new(path: &str) -> Result { - let mut path = path.to_string(); - path.push_str(AGET_EXT); - - let file = File::new(&path, true)?; - Ok(AgetFile { inner: file }) - } - - pub fn open(&mut self) -> Result<&mut Self, AgetError> { - self.inner.open()?; - Ok(self) - } - - pub fn file_name(&self) -> Result { - self.inner.file_name() - } - - pub fn exists(&self) -> bool { - self.inner.exists() - } - - pub fn remove(&self) -> Result<(), AgetError> { - self.inner.remove() - } - - /// Get downloading file's content length stored in the aget file - pub fn content_length(&mut self) -> Result { - let mut buf: [u8; 8] = [0; 8]; - self.inner.read(&mut buf, Some(SeekFrom::Start(0)))?; - let content_length = u8x8_to_u64(&buf); - Ok(content_length) - } - - pub fn completed_length(&mut self) -> Result { - let completed_intervals = self.completed_intervals()?; - if completed_intervals.is_empty() { - Ok(0) - } else { - Ok(completed_intervals.iter().map(RangePart::length).sum()) - } - } - - pub fn completed_intervals(&mut self) -> Result, AgetError> { - let mut intervals: Vec<(u64, u64)> = Vec::new(); - - let mut buf: [u8; 16] = [0; 16]; - self.inner.seek(SeekFrom::Start(8))?; - loop { - let s = self.inner.read(&mut buf, None)?; - if s != 16 { - break; - } - - let mut raw = [0; 8]; - raw.clone_from_slice(&buf[..8]); - let start = u8x8_to_u64(&raw); - raw.clone_from_slice(&buf[8..]); - let end = u8x8_to_u64(&raw); - - assert!( - start <= end, - format!( - "Bug: `start > end` in an interval of aget file. : {} > {}", - start, end - ) - ); - - intervals.push((start, end)); - } - - intervals.sort(); - - // merge intervals - let mut merge_intervals: Vec<(u64, u64)> = Vec::new(); - if !intervals.is_empty() { - merge_intervals.push(intervals[0]); - } - for (start, end) in intervals.iter() { - let (pre_start, pre_end) = merge_intervals.last().unwrap().clone(); - - // case 1 - // ---------- - // ----------- - if pre_end + 1 < *start { - merge_intervals.push((*start, *end)); - // case 2 - // ----------------- - // ---------- - // -------- - // ------ - } else { - let n_start = pre_start; - let n_end = max(pre_end, *end); - merge_intervals.pop(); - merge_intervals.push((n_start, n_end)); - } - } - - Ok(merge_intervals - .iter() - .map(|(start, end)| RangePart::new(*start, *end)) - .collect::>()) - } - - /// Get gaps of undownload pieces - pub fn gaps(&mut self) -> Result, AgetError> { - let mut completed_intervals = self.completed_intervals()?; - let content_length = self.content_length()?; - completed_intervals.push(RangePart::new(content_length, content_length)); - - // find gaps - let mut gaps: Vec = Vec::new(); - // find first chunk - let RangePart { start, .. } = completed_intervals[0]; - if start > 0 { - gaps.push(RangePart::new(0, start - 1)); - } - - for (index, RangePart { end, .. }) in completed_intervals.iter().enumerate() { - if let Some(RangePart { - start: next_start, .. - }) = completed_intervals.get(index + 1) - { - if end + 1 < *next_start { - gaps.push(RangePart::new(end + 1, next_start - 1)); - } - } - } - - Ok(gaps) - } - - pub fn write_content_length(&mut self, content_length: u64) -> Result<(), AgetError> { - let buf = u64_to_u8x8(content_length); - self.inner.write(&buf, Some(SeekFrom::Start(0)))?; - Ok(()) - } - - pub fn write_interval(&mut self, interval: RangePart) -> Result<(), AgetError> { - let start = u64_to_u8x8(interval.start); - let end = u64_to_u8x8(interval.end); - let buf = [start, end].concat(); - self.inner.write(&buf, Some(SeekFrom::End(0)))?; - Ok(()) - } - - // Merge completed intervals and rewrite the aget file - pub fn rewrite(&mut self) -> Result<(), AgetError> { - let content_length = self.content_length()?; - let completed_intervals = self.completed_intervals()?; - - let mut buf: Vec = Vec::new(); - buf.extend(&u64_to_u8x8(content_length)); - for interval in completed_intervals.iter() { - buf.extend(&u64_to_u8x8(interval.start)); - buf.extend(&u64_to_u8x8(interval.end)); - } - - self.inner.set_len(0)?; - self.inner.write(buf.as_slice(), Some(SeekFrom::Start(0)))?; - - Ok(()) - } -} - -/// Create an integer value from its representation as a byte array in big endian. -pub fn u8x8_to_u64(u8x8: &[u8; 8]) -> u64 { - u64::from_be_bytes(*u8x8) -} - -/// Return the memory representation of this integer as a byte array in big-endian (network) byte -/// order. -pub fn u64_to_u8x8(u: u64) -> [u8; 8] { - u.to_be_bytes() -} diff --git a/src/task.rs b/src/task.rs deleted file mode 100644 index a517aa4..0000000 --- a/src/task.rs +++ /dev/null @@ -1,103 +0,0 @@ -use std::time::Duration; - -use awc::http::header; - -use futures::{channel::mpsc::Sender, stream::StreamExt}; -use futures_util::sink::SinkExt; - -use bytes::Bytes; - -use crate::{ - chunk::{RangePart, RangeStack}, - error::NetError, - request::AgetRequestOptions, -}; - -pub struct RequestTask { - range_stack: RangeStack, - sender: Sender<(RangePart, Bytes)>, -} - -impl RequestTask { - pub fn new(range_stack: RangeStack, sender: Sender<(RangePart, Bytes)>) -> RequestTask { - RequestTask { - range_stack, - sender, - } - } - - fn pop_range(&mut self) -> Option { - let mut stack = self.range_stack.borrow_mut(); - (*stack).pop() - } - - fn push_range(&mut self, range: RangePart) { - let mut stack = self.range_stack.borrow_mut(); - (*stack).push(range); - } - - pub async fn run(&mut self, mut options: AgetRequestOptions) -> Result<(), NetError> { - while let Some(range) = self.pop_range() { - let timeout = if options.is_concurrent() { - Duration::from_secs(60) - } else { - Duration::from_secs(10 * 24 * 60 * 60) // 10 days - }; - - let mut client_request = options.build()?.timeout(timeout); - if options.is_concurrent() { - client_request.headers_mut().insert( - header::RANGE, - format!("bytes={}-{}", range.start, range.end).parse()?, - ); - } - let resp = if let Some(body) = options.body() { - client_request.send_body(body).await - } else { - client_request.send().await - }; - - if resp.is_err() { - self.push_range(range); - continue; - } - - let mut resp = resp.unwrap(); - - let status = resp.status(); - - // handle redirect - if status.is_redirection() { - if let Some(location) = resp.headers().get(header::LOCATION) { - options.set_uri(location.to_str()?); - } - self.push_range(range); - continue; - } - - if !status.is_success() { - debug!("request error", status); - self.push_range(range); - continue; - } - - let mut range = range.clone(); - while let Some(chunk) = resp.next().await { - if let Ok(chunk) = chunk { - let len = chunk.len() as u64; - let start = range.start; - let end = start + len; - range.start = end; - - let mut sender = self.sender.clone(); - // the sended RangePart is a close interval as header `Range` - sender.send((RangePart::new(start, end - 1), chunk)).await?; - } else { - self.push_range(range); - break; - } - } - } - Ok(()) - } -} diff --git a/src/util.rs b/src/util.rs deleted file mode 100644 index 6333ed0..0000000 --- a/src/util.rs +++ /dev/null @@ -1,154 +0,0 @@ -use std::path::Path; - -use awc::http::Uri; - -use term_size::dimensions; - -use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS}; - -use crate::error::{AgetError, ArgError, Result}; - -pub trait FindFilename { - fn find_file_name(&self) -> Result<&str, AgetError>; -} - -impl FindFilename for Uri { - fn find_file_name(&self) -> Result<&str, AgetError> { - let path = Path::new(self.path()); - if let Some(file_name) = path.file_name() { - Ok(file_name.to_str().unwrap()) - } else { - Err(AgetError::NoFilename) - } - } -} - -const MIN_TERMINAL_WIDTH: u64 = 60; - -pub fn terminal_width() -> u64 { - if let Some((width, _)) = dimensions() { - width as u64 - } else { - // for envrionment in which atty is not available, - // example, at ci of osx - MIN_TERMINAL_WIDTH - } -} - -const SIZES: [&'static str; 4] = ["B", "K", "M", "G"]; - -pub trait SizeOfFmt { - fn sizeof_fmt(&self) -> String; -} - -impl SizeOfFmt for u64 { - fn sizeof_fmt(&self) -> String { - let mut num = *self as f64; - for s in SIZES.iter() { - if num < 1024.0 { - return format!("{:.1}{}", num, s); - } - num /= 1024.0; - } - return format!("{:.1}{}", num, "G"); - } -} - -impl SizeOfFmt for f64 { - fn sizeof_fmt(&self) -> String { - let mut num = *self; - for s in SIZES.iter() { - if num < 1024.0 { - return format!("{:.1}{}", num, s); - } - num /= 1024.0; - } - return format!("{:.1}{}", num, "G"); - } -} - -pub trait LiteralSize { - fn literal_size(&self) -> Result; -} - -impl LiteralSize for &str { - fn literal_size(&self) -> Result { - let (num, unit) = self.split_at(self.len() - 1); - if unit.parse::().is_err() { - let mut num = num.parse::()?; - for s in SIZES.iter() { - if s.to_lowercase() == unit.to_lowercase() { - return Ok(num); - } else { - num *= 1024; - } - } - Ok(num) - } else { - let num = self.parse::()?; - Ok(num) - } - } -} - -pub trait TimeOfFmt { - fn timeof_fmt(&self) -> String; -} - -impl TimeOfFmt for u64 { - fn timeof_fmt(&self) -> String { - let mut num = *self as f64; - if num < 60.0 { - return format!("{:.0}s", num); - } - num /= 60.0; - if num < 60.0 { - return format!("{:.0}m", num); - } - num /= 60.0; - if num < 24.0 { - return format!("{:.0}h", num); - } - num /= 24.0; - return format!("{:.0}d", num); - } -} - -pub static mut DEBUG: bool = false; -pub static mut QUIET: bool = false; - -#[macro_export] -macro_rules! print_err { - ( $ctx:expr, $err:expr ) => { - eprintln!("[{}:{}] {}: {}", file!(), line!(), $ctx, $err); - }; -} - -#[macro_export] -macro_rules! debug { - ( $title:expr, $msg:expr ) => { - unsafe { - if crate::util::DEBUG { - eprintln!("[{}:{}] {}: {:#?}", file!(), line!(), $title, $msg); - } - } - }; - ( $title:expr ) => { - unsafe { - if crate::util::DEBUG { - eprintln!("[{}:{}] {}", file!(), line!(), $title); - } - } - }; -} - -const FRAGMENT: &AsciiSet = &CONTROLS - .add(b' ') - .remove(b'?') - .remove(b'/') - .remove(b':') - .remove(b'='); - -pub fn escape_nonascii(target: &str) -> String { - utf8_percent_encode(target, FRAGMENT).to_string() -}