From ea051d88730222e3151dd8d5b6c3b56c252a336a Mon Sep 17 00:00:00 2001 From: PeterDing Date: Sat, 23 May 2020 10:29:47 +0800 Subject: [PATCH] Reconstruct with isahc and async-std `isahc` uses libcurl as its backend, which handles the complicated HTTP protocol. The old code had many messy constructions. In this commit, we rewrite all the code and keep its parts as independent of each other as possible. Now, aget-rs supports asynchronously downloading http/s content and m3u8 stream videos. --- .gitignore | 1 + .travis.yml | 20 +- Cargo.toml | 25 +- src/app.rs | 206 ---------------- src/app/core/http.rs | 397 +++++++++++++++++++++++++++++++ src/app/core/m3u8/common.rs | 150 ++++++++++++ src/app/core/m3u8/m3u8.rs | 229 ++++++++++++++++++ src/app/core/m3u8/mod.rs | 2 + src/app/core/mod.rs | 2 + src/app/mod.rs | 5 + src/app/receive/http_receiver.rs | 126 ++++++++++ src/app/receive/m3u8_receiver.rs | 90 +++++++ src/app/receive/mod.rs | 2 + src/app/show/common.rs | 11 + src/app/show/http_show.rs | 123 ++++++++++ src/app/show/m3u8_show.rs | 119 +++++++++ src/app/show/mod.rs | 3 + src/app/stats/list_stats.rs | 72 ++++++ src/app/stats/mod.rs | 2 + src/app/stats/range_stats.rs | 194 +++++++++++++++ src/app/status/mod.rs | 1 + src/app/status/rate_status.rs | 63 +++++ src/{ => arguments}/clap_app.rs | 17 +- src/arguments/cmd_args.rs | 239 +++++++++++++++++++ src/arguments/mod.rs | 2 + src/chunk.rs | 40 ---- src/common.rs | 1 - src/common/buf.rs | 2 + src/common/bytes/bytes.rs | 25 ++ src/common/bytes/bytes_type.rs | 1 + src/common/bytes/mod.rs | 2 + src/common/character.rs | 24 ++ src/common/colors.rs | 1 + src/common/crypto.rs | 8 + src/common/debug.rs | 27 +++ src/common/errors.rs | 91 +++++++ src/common/file.rs | 107 +++++++++ src/common/liberal.rs | 54 +++++ src/common/list.rs | 30 +++ src/common/mod.rs | 18 ++ src/common/net/mod.rs | 2 + src/common/net/net.rs | 184 ++++++++++++++ src/common/net/net_type.rs | 16 ++ src/common/range.rs | 70 ++++++ src/common/size.rs | 34 +++ 
src/common/tasks.rs | 5 + src/common/terminal.rs | 13 + src/common/uri.rs | 22 ++ src/core.rs | 314 ------------------------ src/error.rs | 182 -------------- src/features/args.rs | 52 ++++ src/features/mod.rs | 3 + src/features/running.rs | 5 + src/features/stack.rs | 7 + src/main.rs | 125 ++++------ src/printer.rs | 156 ------------ src/request.rs | 198 --------------- src/store.rs | 362 ---------------------------- src/task.rs | 103 -------- src/util.rs | 154 ------------ 60 files changed, 2726 insertions(+), 1813 deletions(-) delete mode 100644 src/app.rs create mode 100644 src/app/core/http.rs create mode 100644 src/app/core/m3u8/common.rs create mode 100644 src/app/core/m3u8/m3u8.rs create mode 100644 src/app/core/m3u8/mod.rs create mode 100644 src/app/core/mod.rs create mode 100644 src/app/mod.rs create mode 100644 src/app/receive/http_receiver.rs create mode 100644 src/app/receive/m3u8_receiver.rs create mode 100644 src/app/receive/mod.rs create mode 100644 src/app/show/common.rs create mode 100644 src/app/show/http_show.rs create mode 100644 src/app/show/m3u8_show.rs create mode 100644 src/app/show/mod.rs create mode 100644 src/app/stats/list_stats.rs create mode 100644 src/app/stats/mod.rs create mode 100644 src/app/stats/range_stats.rs create mode 100644 src/app/status/mod.rs create mode 100644 src/app/status/rate_status.rs rename src/{ => arguments}/clap_app.rs (87%) create mode 100644 src/arguments/cmd_args.rs create mode 100644 src/arguments/mod.rs delete mode 100644 src/chunk.rs delete mode 100644 src/common.rs create mode 100644 src/common/buf.rs create mode 100644 src/common/bytes/bytes.rs create mode 100644 src/common/bytes/bytes_type.rs create mode 100644 src/common/bytes/mod.rs create mode 100644 src/common/character.rs create mode 100644 src/common/colors.rs create mode 100644 src/common/crypto.rs create mode 100644 src/common/debug.rs create mode 100644 src/common/errors.rs create mode 100644 src/common/file.rs create mode 100644 
src/common/liberal.rs create mode 100644 src/common/list.rs create mode 100644 src/common/mod.rs create mode 100644 src/common/net/mod.rs create mode 100644 src/common/net/net.rs create mode 100644 src/common/net/net_type.rs create mode 100644 src/common/range.rs create mode 100644 src/common/size.rs create mode 100644 src/common/tasks.rs create mode 100644 src/common/terminal.rs create mode 100644 src/common/uri.rs delete mode 100644 src/core.rs delete mode 100644 src/error.rs create mode 100644 src/features/args.rs create mode 100644 src/features/mod.rs create mode 100644 src/features/running.rs create mode 100644 src/features/stack.rs delete mode 100644 src/printer.rs delete mode 100644 src/request.rs delete mode 100644 src/store.rs delete mode 100644 src/task.rs delete mode 100644 src/util.rs diff --git a/.gitignore b/.gitignore index ea8c4bf..96ef6c0 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +Cargo.lock diff --git a/.travis.yml b/.travis.yml index 589f77f..2320e28 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,14 +4,19 @@ dist: xenial matrix: include: - # Stable channel. 
+ # Build on linux - os: linux - rust: stable - env: TARGET=x86_64-unknown-linux-gnu + rust: nightly + env: + - TARGET=x86_64-unknown-linux-gnu + - DEPLOY=true + # Build on osx - os: osx - rust: stable - env: TARGET=x86_64-apple-darwin + rust: nightly + env: + - TARGET=x86_64-apple-darwin + - DEPLOY=true # Code formatting check - os: linux @@ -21,8 +26,9 @@ matrix: - cargo install --debug --force rustfmt-nightly script: cargo fmt -- --check + # Test benchmark - os: linux - rust: stable + rust: nightly install: cargo build script: bash ci/benchmark.bash @@ -75,7 +81,7 @@ deploy: # deploy only if we push a tag tags: true # deploy only on stable channel that has TARGET env variable sets - condition: $TRAVIS_RUST_VERSION = stable && $TARGET != "" + condition: $DEPLOY = true && $TARGET != "" notifications: email: diff --git a/Cargo.toml b/Cargo.toml index 018bc5a..b8f5de3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,19 +19,18 @@ edition = "2018" name = "ag" path = "src/main.rs" +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + [dependencies] -percent-encoding = "2.1" -term_size = "0.3" -ansi_term = "0.12" bytes = "0.5" -failure = "0.1" - +isahc = { version = "0.9", features = ["cookies"] } futures = "0.3" -futures-util = "0.3" -actix-rt = "1.0" -awc = { version = "1.0", features = ["openssl"] } - -[dependencies.clap] -version = "2.33" -default-features = false -features = [ "suggestions", "color" ] +async-std = { version = "1", features = ["unstable"] } +thiserror = "1.0" +term_size = "0.3" +ansi_term = "0.12" +openssl = "0.10" +clap = "2" +percent-encoding = "2" +m3u8-rs = "1" +url = "2" diff --git a/src/app.rs b/src/app.rs deleted file mode 100644 index 62ae172..0000000 --- a/src/app.rs +++ /dev/null @@ -1,206 +0,0 @@ -use std::{env, path::Path}; - -#[cfg(windows)] -use ansi_term::enable_ansi_support; - -use clap::ArgMatches; - -use percent_encoding::percent_decode; - -use awc::http::Uri; - -use crate::{ - 
clap_app::build_app, - common::AGET_EXT, - error::{ArgError, Result}, - util::{escape_nonascii, LiteralSize}, -}; - -#[derive(Debug, Clone)] -pub struct Config { - pub(crate) uri: String, - pub(crate) method: String, - pub(crate) headers: Vec, - pub(crate) data: Option, - pub(crate) path: String, - pub(crate) concurrency: u64, - pub(crate) chunk_length: u64, - pub(crate) timeout: u64, - pub(crate) max_retries: u32, - pub(crate) retry_wait: u64, - pub(crate) debug: bool, - pub(crate) quiet: bool, -} - -impl Config { - pub fn new( - uri: String, - method: String, - headers: Vec, - data: Option, - path: String, - concurrency: u64, - timeout: u64, - chunk_length: u64, - max_retries: u32, - retry_wait: u64, - debug: bool, - quiet: bool, - ) -> Config { - Config { - uri, - method, - headers, - data, - path, - concurrency, - timeout, - chunk_length, - max_retries, - retry_wait, - debug, - quiet, - } - } -} - -pub struct App { - pub matches: ArgMatches<'static>, -} - -impl App { - pub fn new() -> App { - #[cfg(windows)] - let _ = enable_ansi_support(); - - App { - matches: Self::matches(), - } - } - - fn matches() -> ArgMatches<'static> { - let args = env::args(); - let matches = build_app().get_matches_from(args); - matches - } - - pub fn config(&self) -> Result { - // uri - let uri = self.matches.value_of("URL").map(escape_nonascii).unwrap(); - - // path - let path = if let Some(path) = self.matches.value_of("out") { - path.to_string() - } else { - let uri = &uri.parse::()?; - let path = Path::new(uri.path()); - if let Some(file_name) = path.file_name() { - percent_decode(file_name.to_str().unwrap().as_bytes()) - .decode_utf8() - .unwrap() - .to_string() - } else { - return Err(ArgError::NoFilename); - } - }; - - // check status of task - let path_ = Path::new(&path); - let mut file_name = path_.file_name().unwrap().to_os_string(); - file_name.push(AGET_EXT); - let mut aget_path = path_.to_path_buf(); - aget_path.set_file_name(file_name); - if path_.is_dir() { - return 
Err(ArgError::PathIsDirectory); - } - if path_.exists() && !aget_path.as_path().exists() { - return Err(ArgError::FileExists); - } - - let path = path.to_string(); - - // data - let data = if let Some(data) = self.matches.value_of("data") { - Some(data.to_string()) - } else { - None - }; - - // method - let method = if let Some(method) = self.matches.value_of("method") { - method.to_string() - } else { - if data.is_some() { - "POST".to_owned() - } else { - "GET".to_owned() - } - }; - - // headers - let headers = if let Some(headers) = self.matches.values_of("header") { - headers.map(String::from).collect::>() - } else { - Vec::new() - }; - - // concurrency - let concurrency = if let Some(concurrency) = self.matches.value_of("concurrency") { - concurrency.parse::()? - } else { - 10 - }; - - // chunk length - let chunk_length = if let Some(chunk_length) = self.matches.value_of("chunk-length") { - chunk_length.literal_size()? - } else { - 1024 * 500 // 500k - }; - - // timeout - let timeout = if let Some(timeout) = self.matches.value_of("timeout") { - timeout.parse::()? - } else { - 10 - }; - - // maximum retries - let max_retries = if let Some(max_retries) = self.matches.value_of("max_retries") { - max_retries.parse::()? - } else { - 5 - }; - - let retry_wait = if let Some(retry_wait) = self.matches.value_of("retry_wait") { - retry_wait.parse::()? - } else { - 5 - }; - - let debug = self.matches.is_present("debug"); - - let quiet = self.matches.is_present("quiet"); - - Ok(Config::new( - uri, - method, - headers, - data, - path, - concurrency, - timeout, - chunk_length, - max_retries, - retry_wait, - debug, - quiet, - )) - } -} - -fn test_escape_nonascii() { - let s = ":ss/s 来;】/ 【【 ? 
是的 & 水电费=45 进来看"; - println!("{}", s); - println!("{}", escape_nonascii(s)); -} diff --git a/src/app/core/http.rs b/src/app/core/http.rs new file mode 100644 index 0000000..1cf83fb --- /dev/null +++ b/src/app/core/http.rs @@ -0,0 +1,397 @@ +use std::{path::PathBuf, sync::Arc}; + +use async_std::{io::ReadExt, process::exit, task as std_task}; + +use futures::{ + channel::mpsc::{channel, Sender}, + SinkExt, +}; + +use crate::{ + app::{ + receive::http_receiver::HttpReceiver, + stats::range_stats::{RangeStats, RANGESTATS_FILE_SUFFIX}, + }, + common::{ + buf::SIZE, + bytes::bytes_type::{Buf, Bytes, BytesMut}, + errors::{Error, Result}, + file::File, + net::{ + net::{build_http_client, content_length, redirect, request}, + net_type::{ContentLengthValue, HttpClient, Method, Uri}, + }, + range::{split_pair, RangePair, SharedRangList}, + }, + features::{args::Args, running::Runnable, stack::StackLike}, +}; + +/// Http task handler +pub struct HttpHandler { + output: PathBuf, + method: Method, + uri: Uri, + headers: Vec<(String, String)>, + data: Option, + timeout: u64, + concurrency: u64, + chunk_size: u64, + retries: u64, + retry_wait: u64, + proxy: Option, + client: Arc, +} + +impl HttpHandler { + pub fn new(args: &impl Args) -> Result { + let headers = args.headers(); + let timeout = args.timeout(); + let proxy = args.proxy(); + + let hds: Vec<(&str, &str)> = headers + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect(); + let client = build_http_client(hds.as_ref(), timeout, proxy.as_deref())?; + + debug!("HttpHandler::new"); + + Ok(HttpHandler { + output: args.output(), + method: args.method(), + uri: args.uri(), + headers, + data: args.data().map(|ref mut d| d.to_bytes()), + timeout, + concurrency: args.concurrency(), + chunk_size: args.chunk_length(), + retries: args.retries(), + retry_wait: args.retry_wait(), + proxy, + client: Arc::new(client), + }) + } + + async fn start(&mut self) -> Result<()> { + debug!("HttpHandler::start"); + + // 0. 
Check whether task is completed + debug!("HttpHandler: check whether task is completed"); + let mut rangestats = + RangeStats::new(&*(self.output.to_string_lossy() + RANGESTATS_FILE_SUFFIX))?; + if self.output.exists() && !rangestats.exists() { + return Ok(()); + } + + // 1. Redirect + debug!("HttpHandler: redirect start"); + let uri = redirect( + &self.client, + self.method.clone(), + self.uri.clone(), + self.data.clone(), + ) + .await?; + debug!("HttpHandler: redirect to", uri); + self.uri = uri; + + // 2. get content_length + debug!("HttpHandler: content_length start"); + let cl = content_length( + &self.client, + self.method.clone(), + self.uri.clone(), + self.data.clone(), + ) + .await?; + debug!("HttpHandler: content_length", cl); + + // 3. Compare recorded content length with the above one + debug!("HttpHandler: compare recorded content length"); + let mut direct = true; + if let ContentLengthValue::RangeLength(cl) = cl { + if self.output.exists() { + if rangestats.exists() { + rangestats.open()?; + } else { + // Task is completed + return Ok(()); + } + } else { + // Init rangestats + rangestats.remove().unwrap_or(()); // Missing error + rangestats.open()?; + } + + let pre_cl = rangestats.total()?; + + // Inital rangestats + if pre_cl == 0 && pre_cl != cl { + rangestats.write_total(cl)?; + direct = false; + } + // Content is empty + else if pre_cl == 0 && pre_cl == cl { + File::new(&self.output, true)?.open()?; + rangestats.remove()?; + return Ok(()); + } + // Content length is not consistent + else if pre_cl != 0 && pre_cl != cl { + return Err(Error::ContentLengthIsNotConsistent); + } + // Rewrite statistic status + else if pre_cl != 0 && pre_cl == cl { + rangestats.rewrite()?; + direct = false; + } + } + + // 4. Create channel + let (sender, receiver) = channel::<(RangePair, Bytes)>(self.concurrency as usize + 10); + + // 5. 
Dispatch Task + debug!("HttpHandler: dispatch task: direct", direct); + if direct { + let mut task = DirectRequestTask::new( + self.client.clone(), + self.method.clone(), + self.uri.clone(), + self.data.clone(), + sender.clone(), + ); + std_task::spawn(async move { + task.start().await; + }); + } else { + // Make range pairs stack + let mut stack = vec![]; + let gaps = rangestats.gaps()?; + for gap in gaps.iter() { + let mut list = split_pair(gap, self.chunk_size); + stack.append(&mut list); + } + stack.reverse(); + let stack = SharedRangList::new(stack); + // let stack = SharedRangList::new(rangestats.gaps()?); + debug!("HttpHandler: range stack length", stack.len()); + + let concurrency = std::cmp::min(stack.len() as u64, self.concurrency); + for i in 1..concurrency + 1 { + let mut task = RangeRequestTask::new( + self.client.clone(), + self.method.clone(), + self.uri.clone(), + self.data.clone(), + stack.clone(), + sender.clone(), + i, + ); + std_task::spawn(async move { + task.start().await; + }); + } + } + drop(sender); // Remove the reference and let `Task` to handle it + + // 6. Create receiver + debug!("HttpHandler: create receiver"); + let mut httpreceiver = HttpReceiver::new(&self.output, direct)?; + httpreceiver.start(receiver).await?; + + // 7. Task succeeds. 
Remove rangestats file + rangestats.remove().unwrap_or(()); // Missing error + Ok(()) + } +} + +impl Runnable for HttpHandler { + fn run(&mut self) -> Result<()> { + std_task::block_on(self.start()) + } +} + +/// Directly request the resource without range header +struct DirectRequestTask { + client: Arc, + method: Method, + uri: Uri, + data: Option, + sender: Sender<(RangePair, Bytes)>, +} + +impl DirectRequestTask { + fn new( + client: Arc, + method: Method, + uri: Uri, + data: Option, + sender: Sender<(RangePair, Bytes)>, + ) -> DirectRequestTask { + DirectRequestTask { + client, + method, + uri, + data, + sender, + } + } + + async fn start(&mut self) { + loop { + let resp = request( + &*self.client, + self.method.clone(), + self.uri.clone(), + self.data.clone(), + None, + ) + .await; + if let Err(err) = resp { + print_err!("DirectRequestTask request error", err); + continue; + } + let resp = resp.unwrap(); + + let mut buf = [0; SIZE]; + let mut offset = 0; + let mut reader = resp.into_body(); + loop { + match reader.read(&mut buf).await { + Ok(0) => { + return; + } + Ok(len) => { + let pair = RangePair::new(offset, offset + len as u64 - 1); // The pair is a closed interval + let mut b = BytesMut::from(&buf[..len]); + if self.sender.send((pair, b.to_bytes())).await.is_err() { + break; + } + offset += len as u64; + } + Err(err) => { + print_err!("DirectRequestTask read error", err); + break; + } + } + } + } + } +} + +/// Request the resource with a range header which is in the `SharedRangList` +struct RangeRequestTask { + client: Arc, + method: Method, + uri: Uri, + data: Option, + stack: SharedRangList, + sender: Sender<(RangePair, Bytes)>, + id: u64, +} + +impl RangeRequestTask { + fn new( + client: Arc, + method: Method, + uri: Uri, + data: Option, + stack: SharedRangList, + sender: Sender<(RangePair, Bytes)>, + id: u64, + ) -> RangeRequestTask { + RangeRequestTask { + client, + method, + uri, + data, + stack, + sender, + id, + } + } + + async fn start(&mut 
self) { + while let Some(pair) = self.stack.pop() { + match self.req(pair).await { + // Exit whole process when `Error::InnerError` is returned + Err(Error::InnerError(msg)) => { + print_err!(format!("RangeRequestTask {}: InnerError", self.id), msg); + exit(1); + } + Err(err) => { + print_err!(format!("RangeRequestTask {}: error", self.id), err); + } + _ => {} + } + } + } + + async fn req(&mut self, pair: RangePair) -> Result<()> { + let resp = request( + &*self.client, + self.method.clone(), + self.uri.clone(), + self.data.clone(), + Some(pair), + ) + .await; + + if let Err(err) = resp { + self.stack.push(pair); + return Err(err); + } + let resp = resp.unwrap(); + + let length = pair.length(); + let mut count = 0; + let mut buf = [0; SIZE]; + let mut offset = pair.begin; + let mut reader = resp.into_body(); + loop { + // Reads some bytes from the byte stream. + // + // Returns the number of bytes read from the start of the buffer. + // + // If the return value is `Ok(n)`, then it must be guaranteed that + // `0 <= n <= buf.len()`. A nonzero `n` value indicates that the buffer has been + // filled in with `n` bytes of data. If `n` is `0`, then it can indicate one of two + // scenarios: + // + // 1. This reader has reached its "end of file" and will likely no longer be able to + // produce bytes. Note that this does not mean that the reader will always no + // longer be able to produce bytes. + // 2. The buffer specified was 0 bytes in length. 
+ match reader.read(&mut buf).await { + Ok(0) => { + if count != length { + let pr = RangePair::new(offset, pair.end); + self.stack.push(pr); + return Err(Error::UncompletedRead); + } else { + return Ok(()); + } + } + Ok(len) => { + let pr = RangePair::new(offset, offset + len as u64 - 1); // The pair is a closed interval + let mut b = BytesMut::from(&buf[..len]); + if let Err(err) = self.sender.send((pr, b.to_bytes())).await { + let pr = RangePair::new(offset, pair.end); + self.stack.push(pr); + return Err(Error::InnerError(format!( + "Error at `http::RangeRequestTask`: Sender error: {:?}", + err + ))); + } + offset += len as u64; + count += len as u64; + } + Err(err) => { + let pr = RangePair::new(offset, pair.end); + self.stack.push(pr); + return Err(err.into()); + } + } + } + } +} diff --git a/src/app/core/m3u8/common.rs b/src/app/core/m3u8/common.rs new file mode 100644 index 0000000..690dcad --- /dev/null +++ b/src/app/core/m3u8/common.rs @@ -0,0 +1,150 @@ +use std::collections::HashMap; + +use m3u8_rs::{parse_playlist_res, playlist::Playlist}; + +use crate::common::{ + bytes::{ + bytes::{decode_hex, u32_to_u8x4}, + bytes_type::Bytes, + }, + errors::{Error, Result}, + list::SharedVec, + net::{ + net::{redirect, request}, + net_type::{HttpClient, Method, ResponseExt, Uri, Url}, + }, +}; + +#[derive(Debug, Clone)] +pub struct M3u8Segment { + pub index: u64, + pub method: Method, + pub uri: Uri, + pub data: Option, + pub key: Option<[u8; 16]>, + pub iv: Option<[u8; 16]>, +} + +impl M3u8Segment { + pub fn new( + index: u64, + method: Method, + uri: Uri, + data: Option, + key: Option<[u8; 16]>, + iv: Option<[u8; 16]>, + ) -> M3u8Segment { + M3u8Segment { + index, + method, + uri, + data, + key, + iv, + } + } +} + +pub type M3u8SegmentList = Vec; + +pub type SharedM3u8SegmentList = SharedVec; + +pub async fn get_m3u8( + client: &HttpClient, + method: Method, + uri: Uri, + data: Option, +) -> Result { + // uri -> (key, iv) + let mut keymap: HashMap = HashMap::new(); 
+ let mut uris = vec![uri]; + let mut list = vec![]; + + while let Some(uri) = uris.pop() { + debug!("m3u8", uri); + let u = redirect(client, method.clone(), uri.clone(), data.clone()).await?; + debug!("m3u8 redirect to", u); + + if u != uri { + uris.push(u.clone()); + continue; + } + + let base_url = Url::parse(&format!("{:?}", u))?; + + // Read m3u8 content + let mut resp = request(client, method.clone(), u.clone(), data.clone(), None).await?; + + // Adding "\n" for the case when response content has not "\n" at end. + let cn = resp.text()? + "\n"; + + // Parse m3u8 content + let parsed = parse_playlist_res(cn.as_bytes()); + match parsed { + Ok(Playlist::MasterPlaylist(mut pl)) => { + pl.variants.reverse(); + for variant in &pl.variants { + let url = base_url.join(&variant.uri)?; + let uri: Uri = url.as_str().parse()?; + uris.push(uri); + } + } + Ok(Playlist::MediaPlaylist(pl)) => { + let mut index = pl.media_sequence as u64; + for segment in &pl.segments { + let seg_url = base_url.join(&segment.uri)?; + let seg_uri: Uri = seg_url.as_str().parse()?; + + let (key, iv) = if let Some(key) = &segment.key { + let iv = if let Some(iv) = &key.iv { + let mut i = [0; 16]; + let buf = decode_hex(&iv[2..])?; + i.clone_from_slice(&buf[..]); + i + } else { + let mut iv = [0; 16]; + let index_bin = u32_to_u8x4(index as u32); + iv[12..].clone_from_slice(&index_bin); + iv + }; + if let Some(uri) = &key.uri { + let key_url = base_url.join(&uri)?; + let key_uri: Uri = key_url.as_str().parse()?; + if let Some((k, iv)) = keymap.get(&key_uri) { + (Some(*k), Some(*iv)) + } else { + let k = get_key(client, Method::GET, key_uri.clone()).await?; + keymap.insert(key_uri.clone(), (k, iv)); + debug!("Get key, iv", (k, iv)); + (Some(k), Some(iv)) + } + } else { + (None, None) + } + } else { + (None, None) + }; + + list.push(M3u8Segment::new( + index, + Method::GET, + seg_uri.clone(), + None, + key, + iv, + )); + index += 1; + } + } + Err(_) => return Err(Error::M3U8ParseFail), + } + } + 
Ok(list) +} + +async fn get_key(client: &HttpClient, method: Method, uri: Uri) -> Result<[u8; 16]> { + let mut resp = request(client, method.clone(), uri.clone(), None, None).await?; + let mut cn = [0; 16]; + resp.copy_to(&mut cn[..])?; + Ok(cn) +} diff --git a/src/app/core/m3u8/m3u8.rs b/src/app/core/m3u8/m3u8.rs new file mode 100644 index 0000000..6059168 --- /dev/null +++ b/src/app/core/m3u8/m3u8.rs @@ -0,0 +1,229 @@ +use std::{ + path::PathBuf, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::Duration, +}; + +use async_std::{io::ReadExt, process::exit, task as std_task}; + +use futures::{ + channel::mpsc::{channel, Sender}, + SinkExt, +}; + +use crate::{ + app::{ + core::m3u8::common::{get_m3u8, M3u8Segment, SharedM3u8SegmentList}, + receive::m3u8_receiver::M3u8Receiver, + stats::list_stats::{ListStats, LISTSTATS_FILE_SUFFIX}, + }, + common::{ + bytes::bytes_type::{Buf, Bytes}, + crypto::decrypt_aes128, + errors::{Error, Result}, + net::{ + net::{build_http_client, request}, + net_type::{HttpClient, Method, Uri}, + }, + }, + features::{args::Args, running::Runnable, stack::StackLike}, +}; + +/// M3u8 task handler +pub struct M3u8Handler { + output: PathBuf, + method: Method, + uri: Uri, + headers: Vec<(String, String)>, + data: Option, + timeout: u64, + concurrency: u64, + proxy: Option, + client: Arc, +} + +impl M3u8Handler { + pub fn new(args: &impl Args) -> Result { + let headers = args.headers(); + let timeout = args.timeout(); + let proxy = args.proxy(); + + let hds: Vec<(&str, &str)> = headers + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())) + .collect(); + let client = build_http_client(hds.as_ref(), timeout, proxy.as_deref())?; + + debug!("M3u8Handler::new"); + + Ok(M3u8Handler { + output: args.output(), + method: args.method(), + uri: args.uri(), + headers, + data: args.data().map(|ref mut d| d.to_bytes()), + timeout, + concurrency: args.concurrency(), + proxy, + client: Arc::new(client), + }) + } + + async fn start(&mut self) -> 
Result<()> { + debug!("M3u8Handler::start"); + + // 0. Check whether task is completed + debug!("M3u8Handler: check whether task is completed"); + let mut liststats = + ListStats::new(&*(self.output.to_string_lossy() + LISTSTATS_FILE_SUFFIX))?; + if self.output.exists() && !liststats.exists() { + return Ok(()); + } + + // 1. Redirect + debug!("M3u8Handler: get m3u8"); + let mut ls = get_m3u8( + &self.client, + self.method.clone(), + self.uri.clone(), + self.data.clone(), + ) + .await?; + ls.reverse(); + + if liststats.exists() { + liststats.open()?; + let total = liststats.total()?; + if total != ls.len() as u64 { + return Err(Error::PartsAreNotConsistent); + } else { + let index = liststats.index()?; + ls.truncate((total - index) as usize); + } + } else { + liststats.open()?; + liststats.write_total(ls.len() as u64)?; + } + + let index = ls.last().unwrap().index; + let sharedindex = Arc::new(AtomicU64::new(index)); + let stack = SharedM3u8SegmentList::new(ls); + debug!("M3u8Handler: segments", stack.len()); + + // 4. Create channel + let (sender, receiver) = channel::<(u64, Bytes)>(self.concurrency as usize + 10); + + let concurrency = std::cmp::min(stack.len() as u64, self.concurrency); + for i in 1..concurrency + 1 { + let mut task = RequestTask::new( + self.client.clone(), + stack.clone(), + sender.clone(), + i, + sharedindex.clone(), + ); + std_task::spawn(async move { + task.start().await; + }); + } + drop(sender); // Remove the reference and let `Task` to handle it + + // 6. Create receiver + debug!("M3u8Handler: create receiver"); + let mut m3u8receiver = M3u8Receiver::new(&self.output)?; + m3u8receiver.start(receiver).await?; + + // 7. Task succeeds. 
Remove liststats file + liststats.remove().unwrap_or(()); // Missing error + Ok(()) + } +} + +impl Runnable for M3u8Handler { + fn run(&mut self) -> Result<()> { + std_task::block_on(self.start()) + } +} + +/// Request the resource with a range header which is in the `SharedRangList` +struct RequestTask { + client: Arc, + stack: SharedM3u8SegmentList, + sender: Sender<(u64, Bytes)>, + id: u64, + shared_index: Arc, +} + +impl RequestTask { + fn new( + client: Arc, + stack: SharedM3u8SegmentList, + sender: Sender<(u64, Bytes)>, + id: u64, + sharedindex: Arc, + ) -> RequestTask { + RequestTask { + client, + stack, + sender, + id, + shared_index: sharedindex, + } + } + + async fn start(&mut self) { + while let Some(segment) = self.stack.pop() { + loop { + match self.req(segment.clone()).await { + // Exit whole process when `Error::InnerError` is returned + Err(Error::InnerError(msg)) => { + print_err!(format!("RequestTask {}: InnerError", self.id), msg); + exit(1); + } + Err(err) => { + print_err!(format!("RequestTask {}: error", self.id), err); + std_task::sleep(Duration::from_secs(1)).await; + } + _ => break, + } + } + } + } + + async fn req(&mut self, segment: M3u8Segment) -> Result<()> { + let resp = request( + &*self.client, + segment.method.clone(), + segment.uri.clone(), + segment.data.clone(), + None, + ) + .await?; + + let index = segment.index; + let mut buf: Vec = vec![]; + let mut reader = resp.into_body(); + reader.read_to_end(&mut buf).await?; + if let (Some(key), Some(iv)) = (segment.key, segment.iv) { + buf = decrypt_aes128(&key[..], &iv[..], &buf[..])?; + } + + loop { + if self.shared_index.load(Ordering::SeqCst) == index { + if let Err(err) = self.sender.send((index, Bytes::from(buf))).await { + return Err(Error::InnerError(format!( + "Error at `http::RequestTask`: Sender error: {:?}", + err + ))); + } + self.shared_index.store(index + 1, Ordering::SeqCst); + return Ok(()); + } else { + std_task::sleep(Duration::from_millis(500)).await; + } + } + } +} 
diff --git a/src/app/core/m3u8/mod.rs b/src/app/core/m3u8/mod.rs new file mode 100644 index 0000000..10dfd4e --- /dev/null +++ b/src/app/core/m3u8/mod.rs @@ -0,0 +1,2 @@ +pub mod common; +pub mod m3u8; diff --git a/src/app/core/mod.rs b/src/app/core/mod.rs new file mode 100644 index 0000000..1330dcd --- /dev/null +++ b/src/app/core/mod.rs @@ -0,0 +1,2 @@ +pub mod http; +pub mod m3u8; diff --git a/src/app/mod.rs b/src/app/mod.rs new file mode 100644 index 0000000..561b64b --- /dev/null +++ b/src/app/mod.rs @@ -0,0 +1,5 @@ +pub mod core; +pub mod receive; +pub mod show; +pub mod stats; +pub mod status; diff --git a/src/app/receive/http_receiver.rs b/src/app/receive/http_receiver.rs new file mode 100644 index 0000000..d7ab4dd --- /dev/null +++ b/src/app/receive/http_receiver.rs @@ -0,0 +1,126 @@ +use std::{io::SeekFrom, path::Path, time::Duration}; + +use async_std::stream; +use futures::{channel::mpsc::Receiver, select, stream::StreamExt}; + +use crate::{ + app::{ + show::http_show::HttpShower, + stats::range_stats::{RangeStats, RANGESTATS_FILE_SUFFIX}, + status::rate_status::RateStatus, + }, + common::{bytes::bytes_type::Bytes, errors::Result, file::File, range::RangePair}, +}; + +pub struct HttpReceiver { + output: File, + rangestats: Option, + ratestatus: RateStatus, + shower: HttpShower, + // Total content length of the uri + total: u64, +} + +impl HttpReceiver { + pub fn new>(output: P, direct: bool) -> Result { + let mut outputfile = File::new(&output, true)?; + outputfile.open()?; + + let (rangestats, total, completed) = if direct { + (None, 0, 0) + } else { + let mut rangestats = + RangeStats::new(&*(output.as_ref().to_string_lossy() + RANGESTATS_FILE_SUFFIX))?; + rangestats.open()?; + let total = rangestats.total()?; + let completed = rangestats.count()?; + (Some(rangestats), total, completed) + }; + + let mut ratestatus = RateStatus::new(); + ratestatus.set_total(completed); + + Ok(HttpReceiver { + output: outputfile, + rangestats, + ratestatus, + shower: 
HttpShower::new(), + // receiver, + total, + }) + } + + fn show_infos(&mut self) -> Result<()> { + if self.rangestats.is_none() { + self.shower + .print_msg("Server doesn't support range request.")?; + } + + let file_name = &self.output.file_name().unwrap_or("[No Name]"); + let total = self.total; + self.shower.print_file(file_name)?; + self.shower.print_total(total)?; + // self.shower.print_concurrency(concurrency)?; + self.show_status()?; + Ok(()) + } + + fn show_status(&mut self) -> Result<()> { + let total = self.total; + let completed = self.ratestatus.total(); + let rate = self.ratestatus.rate(); + + let eta = if self.rangestats.is_some() { + let remains = total - completed; + // rate > 1.0 for overflow + if remains > 0 && rate > 1.0 { + let eta = (remains as f64 / rate) as u64; + // eta is large than 99 days, return 0 + if eta > 99 * 24 * 60 * 60 { + 0 + } else { + eta + } + } else { + 0 + } + } else { + 0 + }; + + self.shower.print_status(completed, total, rate, eta)?; + self.ratestatus.clean(); + Ok(()) + } + + fn record_pair(&mut self, pair: RangePair) -> Result<()> { + if let Some(ref mut rangestats) = self.rangestats { + rangestats.write_pair(pair)?; + } + Ok(()) + } + pub async fn start(&mut self, receiver: Receiver<(RangePair, Bytes)>) -> Result<()> { + self.show_infos()?; + + let mut tick = stream::interval(Duration::from_secs(2)).fuse(); + let mut receiver = receiver.fuse(); + loop { + select! 
{ + item = receiver.next() => { + if let Some((pair, chunk)) = item { + self.output.write(&chunk[..], Some(SeekFrom::Start(pair.begin)))?; + self.record_pair(pair)?; + self.ratestatus.add(pair.length()); + } else { + break; + } + }, + _ = tick.next() => { + self.show_status()?; + }, + } + } + self.show_status()?; + Ok(()) + } +} diff --git a/src/app/receive/m3u8_receiver.rs b/src/app/receive/m3u8_receiver.rs new file mode 100644 index 0000000..e84a003 --- /dev/null +++ b/src/app/receive/m3u8_receiver.rs @@ -0,0 +1,90 @@ +use std::{io::SeekFrom, path::Path, time::Duration}; + +use async_std::stream; +use futures::{channel::mpsc::Receiver, select, stream::StreamExt}; + +use crate::{ + app::{ + show::m3u8_show::M3u8Shower, + stats::list_stats::{ListStats, LISTSTATS_FILE_SUFFIX}, + status::rate_status::RateStatus, + }, + common::{bytes::bytes_type::Bytes, errors::Result, file::File}, +}; + +pub struct M3u8Receiver { + output: File, + liststats: ListStats, + ratestatus: RateStatus, + shower: M3u8Shower, + // Total number of the `SharedM3u8SegmentList` + total: u64, + completed: u64, +} + +impl M3u8Receiver { + pub fn new>(output: P) -> Result { + let mut outputfile = File::new(&output, true)?; + outputfile.open()?; + + let mut liststats = + ListStats::new(&*(output.as_ref().to_string_lossy() + LISTSTATS_FILE_SUFFIX))?; + liststats.open()?; + let total = liststats.total()?; + let completed = liststats.index()?; + + Ok(M3u8Receiver { + output: outputfile, + liststats, + ratestatus: RateStatus::new(), + shower: M3u8Shower::new(), + total, + completed, + }) + } + + fn show_infos(&mut self) -> Result<()> { + let file_name = &self.output.file_name().unwrap_or("[No Name]"); + let total = self.total; + self.shower.print_file(file_name)?; + self.shower.print_total(total)?; + self.show_status()?; + Ok(()) + } + + fn show_status(&mut self) -> Result<()> { + let total = self.total; + let completed = self.completed; + let rate = self.ratestatus.rate(); + + 
self.shower.print_status(completed, total, rate)?; + self.ratestatus.clean(); + Ok(()) + } + + pub async fn start(&mut self, receiver: Receiver<(u64, Bytes)>) -> Result<()> { + self.show_infos()?; + + let mut tick = stream::interval(Duration::from_secs(2)).fuse(); + let mut receiver = receiver.fuse(); + loop { + select! { + item = receiver.next() => { + if let Some((index, chunk)) = item { + self.output.write(&chunk[..], Some(SeekFrom::End(0)))?; + self.liststats.write_index(index + 1)?; + self.ratestatus.add(chunk.len() as u64); + self.completed = index + 1; + } else { + break; + } + }, + _ = tick.next() => { + self.show_status()?; + }, + } + } + self.show_status()?; + Ok(()) + } +} diff --git a/src/app/receive/mod.rs b/src/app/receive/mod.rs new file mode 100644 index 0000000..19dc232 --- /dev/null +++ b/src/app/receive/mod.rs @@ -0,0 +1,2 @@ +pub mod http_receiver; +pub mod m3u8_receiver; diff --git a/src/app/show/common.rs b/src/app/show/common.rs new file mode 100644 index 0000000..2eb5369 --- /dev/null +++ b/src/app/show/common.rs @@ -0,0 +1,11 @@ +#[cfg(target_os = "windows")] +pub fn bars() -> (&'static str, &'static str, &'static str) { + // bar, bar_right, bar_left + ("▓", "░", " ") +} + +#[cfg(not(target_os = "windows"))] +pub fn bars() -> (&'static str, &'static str, &'static str) { + // bar, bar_right, bar_left + ("━", "╸", "╺") +} diff --git a/src/app/show/http_show.rs b/src/app/show/http_show.rs new file mode 100644 index 0000000..01e56e2 --- /dev/null +++ b/src/app/show/http_show.rs @@ -0,0 +1,123 @@ +use std::io::{stdout, Stdout, Write}; + +use crate::{ + app::show::common::bars, + common::{ + colors::{Black, Blue, Cyan, Green, Red, Yellow}, + errors::Result, + liberal::ToDate, + size::HumanReadable, + terminal::terminal_width, + }, +}; + +pub struct HttpShower { + stdout: Stdout, +} + +impl HttpShower { + pub fn new() -> HttpShower { + HttpShower { stdout: stdout() } + } + + pub fn print_msg(&mut self, msg: &str) -> Result<()> { + writeln!(&mut 
self.stdout, "\n {}", Yellow.italic().paint(msg))?; + Ok(()) + } + + pub fn print_file(&mut self, path: &str) -> Result<()> { + writeln!( + &mut self.stdout, + // "\n {}: {}", + "\n{}: {}", + Green.bold().paint("File"), + path, + )?; + Ok(()) + } + + pub fn print_total(&mut self, total: u64) -> Result<()> { + writeln!( + &mut self.stdout, + "{}: {} ({})", + Blue.bold().paint("Length"), + total.human_readable(), + total, + )?; + Ok(()) + } + + pub fn print_concurrency(&mut self, concurrency: u64) -> Result<()> { + writeln!( + &mut self.stdout, + "{}: {}\n", + Yellow.bold().paint("concurrency"), + concurrency, + )?; + Ok(()) + } + + pub fn print_status(&mut self, completed: u64, total: u64, rate: f64, eta: u64) -> Result<()> { + let percent = completed as f64 / total as f64; + + let completed_str = completed.human_readable(); + let total_str = total.human_readable(); + let percent_str = format!("{:.2}", percent * 100.0); + let rate_str = rate.human_readable(); + let eta_str = eta.date(); + + // maximum info length is 41 e.g. 
+ // 1001.3k/1021.9m 97.98% 1003.1B/s eta: 12s + let info = format!( + "{completed}/{total} {percent}% {rate}/s eta: {eta}", + completed = completed_str, + total = total_str, + percent = percent_str, + rate = rate_str, + eta = eta_str, + ); + + // set default info length + let info_length = 41; + let miss = info_length - info.len(); + + let terminal_width = terminal_width(); + let bar_length = terminal_width - info_length as u64 - 3; + let process_bar_length = (bar_length as f64 * percent) as u64; + let blank_length = bar_length - process_bar_length; + + let (bar, bar_right, bar_left) = bars(); + + let bar_done_str = if process_bar_length > 0 { + format!( + "{}{}", + bar.repeat((process_bar_length - 1) as usize), + bar_right + ) + } else { + "".to_owned() + }; + let bar_undone_str = if blank_length > 0 { + format!("{}{}", bar_left, bar.repeat(blank_length as usize - 1)) + } else { + "".to_owned() + }; + + write!( + &mut self.stdout, + "\r{completed}/{total} {percent}% {rate}/s eta: {eta}{miss} {process_bar}{blank} ", + completed = Red.bold().paint(completed_str), + total = Green.bold().paint(total_str), + percent = Yellow.bold().paint(percent_str), + rate = Blue.bold().paint(rate_str), + eta = Cyan.bold().paint(eta_str), + miss = " ".repeat(miss), + process_bar = Red.bold().paint(bar_done_str), + blank = Black.bold().paint(bar_undone_str), + )?; + + self.stdout.flush()?; + + Ok(()) + } +} diff --git a/src/app/show/m3u8_show.rs b/src/app/show/m3u8_show.rs new file mode 100644 index 0000000..719b687 --- /dev/null +++ b/src/app/show/m3u8_show.rs @@ -0,0 +1,119 @@ +use std::io::{stdout, Stdout, Write}; + +use crate::{ + app::show::common::bars, + common::{ + colors::{Black, Blue, Green, Red, Yellow}, + errors::Result, + size::HumanReadable, + terminal::terminal_width, + }, +}; + +pub struct M3u8Shower { + stdout: Stdout, +} + +impl M3u8Shower { + pub fn new() -> M3u8Shower { + M3u8Shower { stdout: stdout() } + } + + pub fn print_msg(&mut self, msg: &str) -> Result<()> 
{ + writeln!(&mut self.stdout, "\n {}", Yellow.italic().paint(msg))?; + Ok(()) + } + + pub fn print_file(&mut self, path: &str) -> Result<()> { + writeln!( + &mut self.stdout, + // "\n {}: {}", + "\n{}: {}", + Green.bold().paint("File"), + path, + )?; + Ok(()) + } + + pub fn print_total(&mut self, total: u64) -> Result<()> { + writeln!( + &mut self.stdout, + "{}: {}", + Blue.bold().paint("Segments"), + total, + )?; + Ok(()) + } + + pub fn print_concurrency(&mut self, concurrency: u64) -> Result<()> { + writeln!( + &mut self.stdout, + "{}: {}\n", + Yellow.bold().paint("concurrency"), + concurrency, + )?; + Ok(()) + } + + pub fn print_status(&mut self, completed: u64, total: u64, rate: f64) -> Result<()> { + let percent = completed as f64 / total as f64; + + let completed_str = completed.to_string(); + let total_str = total.to_string(); + let percent_str = format!("{:.2}", percent * 100.0); + let rate_str = rate.human_readable(); + + // maximum info length is `completed_str.len()` + `total_str.len()` + 19 + // e.g. 
+ // 100/1021 97.98% 1003.1B/s eta: 12s + let info = format!( + "{completed}/{total} {percent}% {rate}/s", + completed = completed_str, + total = total_str, + percent = percent_str, + rate = rate_str, + ); + + // set default info length + let info_length = completed_str.len() + total_str.len() + 19; + let miss = info_length - info.len(); + + let terminal_width = terminal_width(); + let bar_length = terminal_width - info_length as u64 - 3; + let process_bar_length = (bar_length as f64 * percent) as u64; + let blank_length = bar_length - process_bar_length; + + let (bar, bar_right, bar_left) = bars(); + + let bar_done_str = if process_bar_length > 0 { + format!( + "{}{}", + bar.repeat((process_bar_length - 1) as usize), + bar_right + ) + } else { + "".to_owned() + }; + let bar_undone_str = if blank_length > 0 { + format!("{}{}", bar_left, bar.repeat(blank_length as usize - 1)) + } else { + "".to_owned() + }; + + write!( + &mut self.stdout, + "\r{completed}/{total} {percent}% {rate}/s{miss} {process_bar}{blank} ", + completed = Red.bold().paint(completed_str), + total = Green.bold().paint(total_str), + percent = Yellow.bold().paint(percent_str), + rate = Blue.bold().paint(rate_str), + miss = " ".repeat(miss), + process_bar = Red.bold().paint(bar_done_str), + blank = Black.bold().paint(bar_undone_str), + )?; + + self.stdout.flush()?; + + Ok(()) + } +} diff --git a/src/app/show/mod.rs b/src/app/show/mod.rs new file mode 100644 index 0000000..87fee76 --- /dev/null +++ b/src/app/show/mod.rs @@ -0,0 +1,3 @@ +pub mod common; +pub mod http_show; +pub mod m3u8_show; diff --git a/src/app/stats/list_stats.rs b/src/app/stats/list_stats.rs new file mode 100644 index 0000000..fc11ff1 --- /dev/null +++ b/src/app/stats/list_stats.rs @@ -0,0 +1,72 @@ +use std::{io::SeekFrom, path::Path}; + +use crate::common::{ + bytes::bytes::{u64_to_u8x8, u8x8_to_u64}, + errors::Result, + file::File, +}; + +pub const LISTSTATS_FILE_SUFFIX: &'static str = ".ls.aget"; + +/// List statistic +/// +/// 
`ListStats` struct records two numbers: total and index. +/// All information is stored in a local file. +/// +/// [total 8 bytes][index 8 bytes] +/// `total` is given by the user and represents the real total number of items in a list. +pub struct ListStats { + inner: File, +} + +impl ListStats { + pub fn new>(path: P) -> Result { + let inner = File::new(path, true)?; + Ok(ListStats { inner }) + } + + pub fn open(&mut self) -> Result<&mut Self> { + self.inner.open()?; + Ok(self) + } + + pub fn file_name(&self) -> Option<&str> { + self.inner.file_name() + } + + pub fn exists(&self) -> bool { + self.inner.exists() + } + + /// Delete the inner file + pub fn remove(&self) -> Result<()> { + self.inner.remove() + } + + /// Get the total number of list items stored in the aget file + pub fn total(&mut self) -> Result { + let mut buf: [u8; 8] = [0; 8]; + self.inner.read(&mut buf, Some(SeekFrom::Start(0)))?; + let cl = u8x8_to_u64(&buf); + Ok(cl) + } + + pub fn index(&mut self) -> Result { + let mut buf: [u8; 8] = [0; 8]; + self.inner.read(&mut buf, Some(SeekFrom::Start(8)))?; + let cl = u8x8_to_u64(&buf); + Ok(cl) + } + + pub fn write_total(&mut self, total: u64) -> Result<()> { + let buf = u64_to_u8x8(total); + self.inner.write(&buf, Some(SeekFrom::Start(0)))?; + Ok(()) + } + + pub fn write_index(&mut self, index: u64) -> Result<()> { + let buf = u64_to_u8x8(index); + self.inner.write(&buf, Some(SeekFrom::Start(8)))?; + Ok(()) + } +} diff --git a/src/app/stats/mod.rs b/src/app/stats/mod.rs new file mode 100644 index 0000000..57fa9e9 --- /dev/null +++ b/src/app/stats/mod.rs @@ -0,0 +1,2 @@ +pub mod list_stats; +pub mod range_stats; diff --git a/src/app/stats/range_stats.rs b/src/app/stats/range_stats.rs new file mode 100644 index 0000000..a12ccba --- /dev/null +++ b/src/app/stats/range_stats.rs @@ -0,0 +1,194 @@ +use std::{cmp::max, io::SeekFrom, path::Path}; + +use crate::common::{ + bytes::bytes::{u64_to_u8x8, u8x8_to_u64}, + errors::Result, + file::File, + range::{RangeList, 
RangePair}, +}; + +pub const RANGESTATS_FILE_SUFFIX: &'static str = ".rg.aget"; + +/// Range statistic +/// +/// This struct records pairs which are `common::range::RangePair`. +/// All information is stored at a local file. +/// +/// [total 8bit][ [begin1 8bit,end1 8bit] [begin2 8bit,end2 8bit] ... ] +/// `total` position is not sum_i{end_i - begin_i + 1}. It is given by +/// user, presenting as the real total number. +pub struct RangeStats { + inner: File, +} + +impl RangeStats { + pub fn new>(path: P) -> Result { + let inner = File::new(path, true)?; + Ok(RangeStats { inner }) + } + + pub fn open(&mut self) -> Result<&mut Self> { + self.inner.open()?; + Ok(self) + } + + pub fn file_name(&self) -> Option<&str> { + self.inner.file_name() + } + + pub fn exists(&self) -> bool { + self.inner.exists() + } + + /// Delete the inner file + pub fn remove(&self) -> Result<()> { + self.inner.remove() + } + + /// Get downloading file's content length stored in the aget file + pub fn total(&mut self) -> Result { + let mut buf: [u8; 8] = [0; 8]; + self.inner.read(&mut buf, Some(SeekFrom::Start(0)))?; + let cl = u8x8_to_u64(&buf); + Ok(cl) + } + + /// Count the length of total pairs + pub fn count(&mut self) -> Result { + let pairs = self.pairs()?; + if pairs.is_empty() { + Ok(0) + } else { + Ok(pairs.iter().map(RangePair::length).sum()) + } + } + + /// Recorded pairs + pub fn pairs(&mut self) -> Result { + let mut pairs: Vec<(u64, u64)> = Vec::new(); + + let mut buf: [u8; 16] = [0; 16]; + self.inner.seek(SeekFrom::Start(8))?; + loop { + let s = self.inner.read(&mut buf, None)?; + if s != 16 { + break; + } + + let mut raw = [0; 8]; + raw.clone_from_slice(&buf[..8]); + let begin = u8x8_to_u64(&raw); + raw.clone_from_slice(&buf[8..]); + let end = u8x8_to_u64(&raw); + + assert!( + begin <= end, + format!( + "Bug: `begin > end` in an pair of {}. 
: {} > {}", + self.file_name().unwrap_or(""), + begin, + end + ) + ); + + pairs.push((begin, end)); + } + + pairs.sort(); + + // merge pairs + let mut merged_pairs: Vec<(u64, u64)> = Vec::new(); + if !pairs.is_empty() { + merged_pairs.push(pairs[0]); + } + for (begin, end) in pairs.iter() { + let (pre_start, pre_end) = merged_pairs.last().unwrap().clone(); + + // case 1 + // ---------- + // ----------- + if pre_end + 1 < *begin { + merged_pairs.push((*begin, *end)); + // case 2 + // ----------------- + // ---------- + // -------- + // ------ + } else { + let n_start = pre_start; + let n_end = max(pre_end, *end); + merged_pairs.pop(); + merged_pairs.push((n_start, n_end)); + } + } + + Ok(merged_pairs + .iter() + .map(|(begin, end)| RangePair::new(*begin, *end)) + .collect::()) + } + + /// Get gaps between all pairs + /// Each of gap is a closed interval + pub fn gaps(&mut self) -> Result { + let mut pairs = self.pairs()?; + let total = self.total()?; + pairs.push(RangePair::new(total, total)); + + // find gaps + let mut gaps: RangeList = Vec::new(); + // find first chunk + let RangePair { begin, .. } = pairs[0]; + if begin > 0 { + gaps.push(RangePair::new(0, begin - 1)); + } + + for (index, RangePair { end, .. }) in pairs.iter().enumerate() { + if let Some(RangePair { + begin: next_start, .. 
+ }) = pairs.get(index + 1) + { + if end + 1 < *next_start { + gaps.push(RangePair::new(end + 1, next_start - 1)); + } + } + } + + Ok(gaps) + } + + pub fn write_total(&mut self, total: u64) -> Result<()> { + let buf = u64_to_u8x8(total); + self.inner.write(&buf, Some(SeekFrom::Start(0)))?; + Ok(()) + } + + pub fn write_pair(&mut self, pair: RangePair) -> Result<()> { + let begin = u64_to_u8x8(pair.begin); + let end = u64_to_u8x8(pair.end); + let buf = [begin, end].concat(); + self.inner.write(&buf, Some(SeekFrom::End(0)))?; + Ok(()) + } + + // Merge completed pairs and rewrite the aget file + pub fn rewrite(&mut self) -> Result<()> { + let total = self.total()?; + let pairs = self.pairs()?; + + let mut buf: Vec = Vec::new(); + buf.extend(&u64_to_u8x8(total)); + for pair in pairs.iter() { + buf.extend(&u64_to_u8x8(pair.begin)); + buf.extend(&u64_to_u8x8(pair.end)); + } + + // Clean all content of the file and set its length to zero + self.inner.set_len(0)?; + + // Write new data + self.inner.write(buf.as_slice(), Some(SeekFrom::Start(0)))?; + + Ok(()) + } +} diff --git a/src/app/status/mod.rs b/src/app/status/mod.rs new file mode 100644 index 0000000..2d89a27 --- /dev/null +++ b/src/app/status/mod.rs @@ -0,0 +1 @@ +pub mod rate_status; diff --git a/src/app/status/rate_status.rs b/src/app/status/rate_status.rs new file mode 100644 index 0000000..8771ddb --- /dev/null +++ b/src/app/status/rate_status.rs @@ -0,0 +1,63 @@ +use std::time::Instant; + +/// `RateStatus` records the rate of adding number +pub struct RateStatus { + /// Total number + total: u64, + + /// The number at an one tick interval + count: u64, + + /// The interval of one tick + tick: Instant, +} + +impl RateStatus { + pub fn new() -> RateStatus { + RateStatus::default() + } + + pub fn total(&self) -> u64 { + self.total + } + + pub fn set_total(&mut self, total: u64) { + self.total = total; + } + + pub fn count(&self) -> u64 { + self.count + } + + pub fn rate(&self) -> f64 { + let interval = 
self.tick.elapsed().as_secs_f64(); + let rate = self.count as f64 / interval; + rate + } + + pub fn add(&mut self, incr: u64) { + self.total += incr; + self.count += incr; + } + + pub fn reset(&mut self) { + self.total = 0; + self.count = 0; + self.tick = Instant::now(); + } + + pub fn clean(&mut self) { + self.count = 0; + self.tick = Instant::now(); + } +} + +impl Default for RateStatus { + fn default() -> RateStatus { + RateStatus { + total: 0, + count: 0, + tick: Instant::now(), + } + } +} diff --git a/src/clap_app.rs b/src/arguments/clap_app.rs similarity index 87% rename from src/clap_app.rs rename to src/arguments/clap_app.rs index 1c337ab..b86c434 100644 --- a/src/clap_app.rs +++ b/src/arguments/clap_app.rs @@ -1,6 +1,6 @@ use clap::{crate_name, crate_version, App as ClapApp, AppSettings, Arg}; -pub fn build_app() -> ClapApp<'static, 'static> { +pub fn build_app<'a>() -> ClapApp<'a, 'a> { ClapApp::new(crate_name!()) .version(crate_version!()) .global_setting(AppSettings::ColoredHelp) @@ -66,12 +66,18 @@ pub fn build_app() -> ClapApp<'static, 'static> { .multiple(false) .takes_value(true), ) + .arg( + Arg::with_name("proxy") + .long("proxy") + .help("proxy (http/https/socks4/socks5) e.g. 
-p http://localhost:1024") + .multiple(false) + .takes_value(true), + ) .arg( Arg::with_name("timeout") .short("t") .long("timeout") .help("Timeout(seconds) of request") - .default_value("10") .multiple(false) .takes_value(true), ) @@ -91,6 +97,13 @@ pub fn build_app() -> ClapApp<'static, 'static> { .multiple(false) .takes_value(true), ) + .arg( + Arg::with_name("type") + .long("type") + .default_value("http") + .multiple(false) + .help("Task type, http/m3u8"), + ) .arg( Arg::with_name("debug") .long("debug") diff --git a/src/arguments/cmd_args.rs b/src/arguments/cmd_args.rs new file mode 100644 index 0000000..6ccb0a7 --- /dev/null +++ b/src/arguments/cmd_args.rs @@ -0,0 +1,239 @@ +use std::{ + env, fmt, + path::{Path, PathBuf}, +}; + +#[cfg(windows)] +use ansi_term::enable_ansi_support; + +use clap::ArgMatches; + +use percent_encoding::percent_decode; + +use crate::{ + arguments::clap_app::build_app, + common::{ + bytes::bytes_type::BytesMut, + character::escape_nonascii, + errors::Error, + liberal::ParseLiteralNumber, + net::{ + net::parse_headers, + net_type::{Method, Uri}, + }, + tasks::TaskType, + }, + features::args::Args, +}; + +pub struct CmdArgs { + matches: ArgMatches<'static>, +} + +impl CmdArgs { + pub fn new() -> CmdArgs { + #[cfg(windows)] + let _ = enable_ansi_support(); + + let args = env::args(); + let inner = build_app(); + let matches = inner.get_matches_from(args); + CmdArgs { matches } + } +} + +impl Args for CmdArgs { + /// Path of output + fn output(&self) -> PathBuf { + if let Some(path) = self.matches.value_of("out") { + PathBuf::from(path) + } else { + let uri = self.uri(); + let path = Path::new(uri.path()); + if let Some(file_name) = path.file_name() { + PathBuf::from( + percent_decode(file_name.to_str().unwrap().as_bytes()) + .decode_utf8() + .unwrap() + .to_string(), + ) + } else { + panic!("{:?}", Error::NoFilename); + } + } + } + + /// Request method for http + fn method(&self) -> Method { + if let Some(method) = 
self.matches.value_of("method") { + match method.to_uppercase().as_str() { + "GET" => Method::GET, + "POST" => Method::POST, + _ => panic!(format!( + "{:?}", + Error::UnsupportedMethod(method.to_string()) + )), + } + } else { + if self.data().is_some() { + Method::POST + } else { + Method::GET + } + } + } + + /// The uri of a task + fn uri(&self) -> Uri { + self.matches + .value_of("URL") + .map(escape_nonascii) + .unwrap() + .parse() + .unwrap() + } + + /// The data for http post request + fn data(&self) -> Option { + self.matches.value_of("data").map(|d| BytesMut::from(d)) + } + + /// Request headers + fn headers(&self) -> Vec<(String, String)> { + if let Some(headers) = self.matches.values_of("header") { + parse_headers(headers) + .unwrap() + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect::>() + } else { + vec![] + } + } + + /// Proxy taken from the `--proxy` argument or, failing that, from an environment variable + /// + /// The environment variables, checked in this order, can be: + /// http_proxy [protocol://]<host>[:port] + /// Sets the proxy server to use for HTTP. + /// + /// HTTPS_PROXY [protocol://]<host>[:port] + /// Sets the proxy server to use for HTTPS. + /// + /// ALL_PROXY [protocol://]<host>[:port] + /// Sets the proxy server to use if no protocol-specific proxy is set. + /// + /// Protocols: + /// http:// + /// an HTTP proxy + /// https:// + /// an HTTPS proxy + /// socks4:// + /// socks4a:// + /// socks5:// + /// socks5h:// + /// a SOCKS proxy + fn proxy(&self) -> Option { + let p = self.matches.value_of("proxy").map(|i| i.to_string()); + if p.is_some() { + return p; + } + + if let Ok(p) = env::var("http_proxy") { + return Some(p); + } + + if let Ok(p) = env::var("HTTPS_PROXY") { + return Some(p); + } + + if let Ok(p) = env::var("ALL_PROXY") { + return Some(p); + } + + None + } + + /// The maximum time the request is allowed to take. 
+ fn timeout(&self) -> u64 { + self.matches + .value_of("timeout") + .map(|i| i.parse::().unwrap()) + .unwrap_or(0) + } + + /// The concurrency level, default is 10 + fn concurrency(&self) -> u64 { + self.matches + .value_of("concurrency") + .map(|i| i.parse::().unwrap()) + .unwrap_or(10) + } + + /// The chunk length for each concurrency unit of an http task, default is 500k + fn chunk_length(&self) -> u64 { + self.matches + .value_of("chunk-length") + .map(|i| i.literal_number().unwrap()) + .unwrap_or(1024 * 500) // 500k + } + + /// The number of retries of a task, default is 5 + fn retries(&self) -> u64 { + self.matches + .value_of("retries") + .map(|i| i.parse::().unwrap()) + .unwrap_or(5) + } + + /// The interval between retries, default is zero + fn retry_wait(&self) -> u64 { + self.matches + .value_of("retry_wait") + .map(|i| i.parse::().unwrap()) + .unwrap_or(0) + } + + /// Task type + fn task_type(&self) -> TaskType { + self.matches + .value_of("type") + .map(|i| match i.to_lowercase().as_str() { + "http" => TaskType::HTTP, + "m3u8" => TaskType::M3U8, + _ => panic!(format!("{:?}", Error::UnsupportedTask(i.to_string()))), + }) + .unwrap_or(TaskType::HTTP) + } + + /// Debug mode is enabled if this returns true + fn debug(&self) -> bool { + self.matches.is_present("debug") + } + + /// Quiet mode is enabled if this returns true + fn quiet(&self) -> bool { + self.matches.is_present("quiet") + } +} + +impl fmt::Debug for CmdArgs { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CmdArgs") + .field("output", &self.output()) + .field("method", &self.method()) + .field("uri", &self.uri()) + .field("data", &self.data()) + .field("headers", &self.headers()) + .field("proxy", &self.proxy()) + .field("timeout", &self.timeout()) + .field("concurrency", &self.concurrency()) + .field("chunk_length", &self.chunk_length()) + .field("retries", &self.retries()) + .field("retry_wait", &self.retry_wait()) + .field("task_type", &self.task_type()) + .field("debug", &self.debug()) + .field("quiet", 
&self.quiet()) + .finish() + } +} diff --git a/src/arguments/mod.rs b/src/arguments/mod.rs new file mode 100644 index 0000000..351f335 --- /dev/null +++ b/src/arguments/mod.rs @@ -0,0 +1,2 @@ +pub mod clap_app; +pub mod cmd_args; diff --git a/src/chunk.rs b/src/chunk.rs deleted file mode 100644 index c0ea516..0000000 --- a/src/chunk.rs +++ /dev/null @@ -1,40 +0,0 @@ -use std::{cell::RefCell, rc::Rc}; - -#[derive(Debug, Clone)] -pub struct RangePart { - pub start: u64, - pub end: u64, -} - -impl RangePart { - pub fn new(start: u64, end: u64) -> RangePart { - RangePart { start, end } - } - - pub fn length(&self) -> u64 { - self.end - self.start + 1 - } -} - -pub type RangeStack = Rc>>; - -/// Split a close `interval` to many piece chunk that its size is equal to `chunk_length`, -/// but the last piece size can be less then `chunk_length`. -pub fn make_range_chunks(interval: &RangePart, chunk_length: u64) -> Vec { - let mut stack = Vec::new(); - - let mut start = interval.start; - let interval_end = interval.end; - - while start + chunk_length - 1 <= interval_end { - let end = start + chunk_length - 1; - stack.push(RangePart::new(start, end)); - start += chunk_length; - } - - if start <= interval_end { - stack.push(RangePart::new(start, interval_end)); - } - - stack -} diff --git a/src/common.rs b/src/common.rs deleted file mode 100644 index e16d9ec..0000000 --- a/src/common.rs +++ /dev/null @@ -1 +0,0 @@ -pub const AGET_EXT: &'static str = ".aget"; diff --git a/src/common/buf.rs b/src/common/buf.rs new file mode 100644 index 0000000..a979a83 --- /dev/null +++ b/src/common/buf.rs @@ -0,0 +1,2 @@ +/// Default buffer size +pub const SIZE: usize = 2048; diff --git a/src/common/bytes/bytes.rs b/src/common/bytes/bytes.rs new file mode 100644 index 0000000..bc70d79 --- /dev/null +++ b/src/common/bytes/bytes.rs @@ -0,0 +1,25 @@ +use std::num::ParseIntError; + +use crate::common::errors::Result; + +/// Create an integer value from its representation as a byte array in big 
endian. +pub fn u8x8_to_u64(u8x8: &[u8; 8]) -> u64 { + u64::from_be_bytes(*u8x8) +} + +/// Return the memory representation of this integer as a byte array in big-endian (network) byte +/// order. +pub fn u64_to_u8x8(u: u64) -> [u8; 8] { + u.to_be_bytes() +} + +pub fn u32_to_u8x4(u: u32) -> [u8; 4] { + u.to_be_bytes() +} + +pub fn decode_hex(s: &str) -> Result, ParseIntError> { + (0..s.len()) + .step_by(2) + .map(|i| u8::from_str_radix(&s[i..i + 2], 16)) + .collect() +} diff --git a/src/common/bytes/bytes_type.rs b/src/common/bytes/bytes_type.rs new file mode 100644 index 0000000..1bda7da --- /dev/null +++ b/src/common/bytes/bytes_type.rs @@ -0,0 +1 @@ +pub use bytes::{Buf, BufMut, Bytes, BytesMut}; diff --git a/src/common/bytes/mod.rs b/src/common/bytes/mod.rs new file mode 100644 index 0000000..2dcdc9b --- /dev/null +++ b/src/common/bytes/mod.rs @@ -0,0 +1,2 @@ +pub mod bytes; +pub mod bytes_type; diff --git a/src/common/character.rs b/src/common/character.rs new file mode 100644 index 0000000..80a8702 --- /dev/null +++ b/src/common/character.rs @@ -0,0 +1,24 @@ +use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS}; + +const FRAGMENT: &AsciiSet = &CONTROLS + .add(b' ') + .remove(b'?') + .remove(b'/') + .remove(b':') + .remove(b'='); + +pub fn escape_nonascii(target: &str) -> String { + utf8_percent_encode(target, FRAGMENT).to_string() +} + +#[cfg(test)] +mod tests { + use super::escape_nonascii; + + #[test] + fn test_escape_nonascii() { + let s = ":ss/s 来;】/ 【【 ? 
是的 & 水电费=45 进来看"; + println!("{}", s); + println!("{}", escape_nonascii(s)); + } +} diff --git a/src/common/colors.rs b/src/common/colors.rs new file mode 100644 index 0000000..db103b7 --- /dev/null +++ b/src/common/colors.rs @@ -0,0 +1 @@ +pub use ansi_term::Colour::*; diff --git a/src/common/crypto.rs b/src/common/crypto.rs new file mode 100644 index 0000000..08e774a --- /dev/null +++ b/src/common/crypto.rs @@ -0,0 +1,8 @@ +use openssl::symm::{decrypt, Cipher}; + +use crate::common::errors::Result; + +pub fn decrypt_aes128(key: &[u8], iv: &[u8], enc: &[u8]) -> Result> { + let cipher = Cipher::aes_128_cbc(); + Ok(decrypt(cipher, key, Some(iv), enc)?) +} diff --git a/src/common/debug.rs b/src/common/debug.rs new file mode 100644 index 0000000..57d7ce9 --- /dev/null +++ b/src/common/debug.rs @@ -0,0 +1,27 @@ +pub static mut DEBUG: bool = false; +pub static mut QUIET: bool = false; + +#[macro_export] +macro_rules! print_err { + ( $ctx:expr, $err:expr ) => { + eprintln!("[{}:{}] {}: {}", file!(), line!(), $ctx, $err); + }; +} + +#[macro_export] +macro_rules! 
debug { + ( $title:expr, $msg:expr ) => { + unsafe { + if crate::common::debug::DEBUG { + eprintln!("[{}:{}] {}: {:#?}", file!(), line!(), $title, $msg); + } + } + }; + ( $title:expr ) => { + unsafe { + if crate::common::debug::DEBUG { + eprintln!("[{}:{}] {}", file!(), line!(), $title); + } + } + }; +} diff --git a/src/common/errors.rs b/src/common/errors.rs new file mode 100644 index 0000000..8b7f809 --- /dev/null +++ b/src/common/errors.rs @@ -0,0 +1,91 @@ +use std::{io::Error as IoError, num, result}; + +use thiserror::Error as ThisError; + +use isahc::{http, Error as IsahcError}; + +use url::ParseError as UrlParseError; + +use openssl; + +pub type Result = result::Result; + +#[derive(Debug, ThisError)] +pub enum Error { + // For Arguments + #[error("Output path is invalid: {0}")] + InvalidPath(String), + #[error("Uri is invalid: {0}")] + InvaildUri(#[from] http::uri::InvalidUri), + #[error("Header is invalid: {0}")] + InvalidHeader(String), + #[error("No filename.")] + NoFilename, + #[error("Directory is not found")] + NotFoundDirectory, + #[error("The file already exists.")] + FileExists, + #[error("The path is a directory.")] + PathIsDirectory, + #[error("Can't parse string as number: {0}")] + IsNotNumber(#[from] num::ParseIntError), + #[error("Io Error: {0}")] + Io(#[from] IoError), + #[error("{0} task is not supported")] + UnsupportedTask(String), + + // For Network + #[error("Network error: {0}")] + NetError(String), + #[error("Uncompleted Read")] + UncompletedRead, + #[error("{0} is unsupported")] + UnsupportedMethod(String), + #[error("header is invalid: {0}")] + HeaderParseError(String), + #[error("header is invalid: {0}")] + UrlParseError(#[from] UrlParseError), + #[error("BUG: {0}")] + Bug(String), + #[error("The two content lengths are not equal between the response and the aget file.")] + ContentLengthIsNotConsistent, + + // For m3u8 + #[error("Fail to parse m3u8 file.")] + M3U8ParseFail, + #[error("The two m3u8 parts are not equal between the 
response and the aget file.")] + PartsAreNotConsistent, + + #[error("An internal error: {0}")] + InnerError(String), + #[error("Content does not has length")] + NoContentLength, + #[error("header is invalid: {0}")] + InvaildHeader(String), + #[error("response status code is: {0}")] + Unsuccess(u16), + #[error("Redirect to: {0}")] + Redirect(String), + #[error("No Location for redirection: {0}")] + NoLocation(String), + #[error("Fail to decrypt aes128 data: {0}")] + AES128DecryptFail(#[from] openssl::error::ErrorStack), +} + +impl From for Error { + fn from(err: IsahcError) -> Error { + Error::NetError(format!("{}", err)) + } +} + +impl From for Error { + fn from(err: http::header::ToStrError) -> Error { + Error::NetError(format!("{}", err)) + } +} + +impl From for Error { + fn from(err: http::Error) -> Error { + Error::NetError(format!("{}", err)) + } +} diff --git a/src/common/file.rs b/src/common/file.rs new file mode 100644 index 0000000..1b541fe --- /dev/null +++ b/src/common/file.rs @@ -0,0 +1,107 @@ +use std::{ + fs::{create_dir_all, metadata, remove_file, File as StdFile, OpenOptions}, + io::{Read, Seek, SeekFrom, Write}, + path::{Path, PathBuf}, +}; + +use crate::common::errors::{Error, Result}; + +/// File can be readed or writen only by opened. +pub struct File { + path: PathBuf, + file: Option, + readable: bool, +} + +impl File { + pub fn new>(path: P, readable: bool) -> Result { + let path = path.as_ref().to_path_buf(); + if path.is_dir() { + return Err(Error::InvalidPath(format!("{:?}", path))); + } + + Ok(File { + path, + file: None, + readable, + }) + } + + /// Create the dir if it does not exists + fn create_dir>(&self, dir: P) -> Result<()> { + if !dir.as_ref().exists() { + Ok(create_dir_all(dir)?) 
+ } else { + Ok(()) + } + } + + /// Create or open the file + pub fn open(&mut self) -> Result<&mut Self> { + if let Some(dir) = self.path.parent() { + self.create_dir(dir)?; + } + let file = OpenOptions::new() + .read(self.readable) + .write(true) + .truncate(false) + .create(true) + .open(self.path.as_path())?; + self.file = Some(file); + Ok(self) + } + + pub fn exists(&self) -> bool { + self.path.as_path().exists() + } + + pub fn file_name(&self) -> Option<&str> { + if let Some(n) = self.path.as_path().file_name() { + n.to_str() + } else { + None + } + } + + pub fn file(&mut self) -> Result<&mut StdFile> { + if let Some(ref mut file) = self.file { + Ok(file) + } else { + Err(Error::Bug("`store::File::file` must be opened".to_string())) + } + } + + pub fn size(&self) -> u64 { + if let Ok(md) = metadata(&self.path) { + md.len() + } else { + 0 + } + } + + pub fn write(&mut self, buf: &[u8], seek: Option) -> Result { + if let Some(seek) = seek { + self.seek(seek)?; + } + Ok(self.file()?.write(buf)?) + } + + pub fn read(&mut self, buf: &mut [u8], seek: Option) -> Result { + if let Some(seek) = seek { + self.seek(seek)?; + } + Ok(self.file()?.read(buf)?) + } + + pub fn seek(&mut self, seek: SeekFrom) -> Result { + Ok(self.file()?.seek(seek)?) + } + + pub fn set_len(&mut self, size: u64) -> Result<()> { + Ok(self.file()?.set_len(size)?) + } + + pub fn remove(&self) -> Result<()> { + Ok(remove_file(self.path.as_path())?) + } +} diff --git a/src/common/liberal.rs b/src/common/liberal.rs new file mode 100644 index 0000000..5121a10 --- /dev/null +++ b/src/common/liberal.rs @@ -0,0 +1,54 @@ +use crate::common::errors::{Error, Result}; + +const SIZES: [&'static str; 5] = ["B", "K", "M", "G", "T"]; + +/// Convert liberal number to u64 +/// e.g. 
+/// 100k -> 100 * 1024 +pub trait ParseLiteralNumber { + fn literal_number(&self) -> Result; +} + +impl ParseLiteralNumber for &str { + fn literal_number(&self) -> Result { + let (num, unit) = self.split_at(self.len() - 1); + if unit.parse::().is_err() { + let mut num = num.parse::()?; + for s in &SIZES { + if s == &unit.to_uppercase() { + return Ok(num); + } else { + num *= 1024; + } + } + Ok(num) + } else { + let num = self.parse::()?; + Ok(num) + } + } +} + +/// Convert seconds to date format +pub trait ToDate { + fn date(&self) -> String; +} + +impl ToDate for u64 { + fn date(&self) -> String { + let mut num = *self as f64; + if num < 60.0 { + return format!("{:.0}s", num); + } + num /= 60.0; + if num < 60.0 { + return format!("{:.0}m", num); + } + num /= 60.0; + if num < 24.0 { + return format!("{:.0}h", num); + } + num /= 24.0; + return format!("{:.0}d", num); + } +} diff --git a/src/common/list.rs b/src/common/list.rs new file mode 100644 index 0000000..430a845 --- /dev/null +++ b/src/common/list.rs @@ -0,0 +1,30 @@ +use std::sync::{Arc, Mutex}; + +use crate::features::stack::StackLike; + +#[derive(Debug, Clone)] +pub struct SharedVec { + inner: Arc>>, +} + +impl SharedVec { + pub fn new(list: Vec) -> SharedVec { + SharedVec { + inner: Arc::new(Mutex::new(list)), + } + } +} + +impl StackLike for SharedVec { + fn push(&mut self, item: T) { + self.inner.lock().unwrap().push(item) + } + + fn pop(&mut self) -> Option { + self.inner.lock().unwrap().pop() + } + + fn len(&self) -> usize { + self.inner.lock().unwrap().len() + } +} diff --git a/src/common/mod.rs b/src/common/mod.rs new file mode 100644 index 0000000..19e174b --- /dev/null +++ b/src/common/mod.rs @@ -0,0 +1,18 @@ +#[macro_use] +pub mod debug; + +pub mod buf; +pub mod bytes; +pub mod character; +pub mod colors; +pub mod crypto; +pub mod errors; +pub mod file; +pub mod liberal; +pub mod list; +pub mod net; +pub mod range; +pub mod size; +pub mod tasks; +pub mod terminal; +pub mod uri; diff --git 
a/src/common/net/mod.rs b/src/common/net/mod.rs new file mode 100644 index 0000000..2567161 --- /dev/null +++ b/src/common/net/mod.rs @@ -0,0 +1,2 @@ +pub mod net; +pub mod net_type; diff --git a/src/common/net/net.rs b/src/common/net/net.rs new file mode 100644 index 0000000..572b9c5 --- /dev/null +++ b/src/common/net/net.rs @@ -0,0 +1,184 @@ +use std::time; + +use crate::common::{ + bytes::bytes_type::Bytes, + errors::{Error, Result}, + net::net_type::{ + header, Body, Configurable, ContentLengthValue, HttpClient, Method, Request, Response, Uri, + }, + range::RangePair, +}; + +pub fn parse_header(raw: &str) -> Result<(&str, &str), Error> { + if let Some(index) = raw.find(": ") { + return Ok((&raw[..index], &raw[index + 2..])); + } + if let Some(index) = raw.find(":") { + return Ok((&raw[..index], &raw[index + 1..])); + } + Err(Error::InvalidHeader(raw.to_string())) +} + +pub fn parse_headers<'a, I: IntoIterator>( + raws: I, +) -> Result, Error> { + let mut headers = vec![]; + for raw in raws { + let pair = parse_header(raw)?; + headers.push(pair); + } + Ok(headers) +} + +/// Builder a http client of curl +pub fn build_http_client( + headers: &[(&str, &str)], + timeout: u64, + proxy: Option<&str>, +) -> Result { + let mut builder = HttpClient::builder().default_headers(headers).proxy({ + if proxy.is_some() { + Some(proxy.unwrap().parse()?) + } else { + None + } + }); + // If timeout is zero, no timeout will be enforced. 
+ if timeout > 0 { + builder = builder.timeout(time::Duration::from_secs(timeout)); + } + let client = builder.cookies().build()?; + Ok(client) +} + +// TODO +struct RequestInfo { + method: Method, + uri: String, + headers: Vec<(String, String)>, + // Post data + data: Option, + timeout: u64, + proxy: Option, +} + +impl RequestInfo {} + +/// Check whether the response is success +pub fn is_success(resp: &Response) -> Result<(), Error> { + let status = resp.status(); + if !status.is_success() { + Err(Error::Unsuccess(status.as_u16())) + } else { + Ok(()) + } +} + +/// Send a request with a range header, returning the final uri +pub async fn redirect( + client: &HttpClient, + method: Method, + uri: Uri, + data: Option, +) -> Result { + let mut uri = uri; + loop { + let data = data.clone().map(|d| Body::from_bytes(&d)); + let request = Request::builder() + .method(method.clone()) + .uri(uri.clone()) + .header(header::RANGE, "bytes=0-1") + .body(data)?; + let resp = client.send_async(request).await?; + // is_success(&resp)?; + if !resp.status().is_redirection() { + break; + } + let headers = resp.headers(); + if let Some(location) = headers.get(header::LOCATION) { + uri = location.to_str()?.parse()?; + } else { + break; + } + } + Ok(uri) +} + +/// Get the content length of the resource +pub async fn content_length( + client: &HttpClient, + method: Method, + uri: Uri, + data: Option, +) -> Result { + let mut uri = uri; + loop { + let data = data.clone().map(|d| Body::from_bytes(&d)); + let request = Request::builder() + .method(method.clone()) + .uri(uri.clone()) + .header(header::RANGE, "bytes=0-1") + .body(data)?; + let resp = client.send_async(request).await?; + is_success(&resp)?; + let headers = resp.headers(); + if resp.status().is_redirection() { + if let Some(location) = headers.get(header::LOCATION) { + uri = location.to_str()?.parse()?; + continue; + } else { + return Err(Error::NoLocation(format!("{}", uri))); + } + } else { + if let Some(h) = 
headers.get(header::CONTENT_RANGE) { + if let Ok(s) = h.to_str() { + if let Some(index) = s.find("/") { + if let Ok(length) = &s[index + 1..].parse::() { + return Ok(ContentLengthValue::RangeLength(length.clone())); + } + } + } + } + if let Some(h) = resp.headers().get(header::CONTENT_LENGTH) { + if let Ok(s) = h.to_str() { + if let Ok(length) = s.parse::() { + return Ok(ContentLengthValue::DirectLength(length.clone())); + } + } + } + break; + } + } + Ok(ContentLengthValue::NoLength) +} + +/// Send a request +pub async fn request( + client: &HttpClient, + method: Method, + uri: Uri, + data: Option, + range: Option, +) -> Result> { + let mut uri = uri; + loop { + let data = data.clone().map(|d| Body::from_bytes(&d)); + let mut builder = Request::builder().method(method.clone()).uri(uri.clone()); + + if let Some(RangePair { begin, end }) = range { + builder = builder.header(header::RANGE, &format!("bytes={}-{}", begin, end)); + } + let request = builder.body(data)?; + let resp = client.send_async(request).await?; + is_success(&resp)?; + if resp.status().is_redirection() { + if let Some(location) = resp.headers().get(header::LOCATION) { + uri = location.to_str()?.parse()?; + continue; + } else { + return Err(Error::NoLocation(format!("{}", uri))); + } + } + return Ok(resp); + } +} diff --git a/src/common/net/net_type.rs b/src/common/net/net_type.rs new file mode 100644 index 0000000..932cfd5 --- /dev/null +++ b/src/common/net/net_type.rs @@ -0,0 +1,16 @@ +pub use isahc::{ + self, + config::Configurable, + http, + http::{header, HeaderMap, HeaderValue, Method, Request, Response, Uri}, + Body, HttpClient, RequestExt, ResponseExt, +}; + +pub use url::Url; + +#[derive(Debug)] +pub enum ContentLengthValue { + RangeLength(u64), + DirectLength(u64), + NoLength, +} diff --git a/src/common/range.rs b/src/common/range.rs new file mode 100644 index 0000000..99970c5 --- /dev/null +++ b/src/common/range.rs @@ -0,0 +1,70 @@ +use std::sync::{Arc, Mutex}; + +use 
crate::features::stack::StackLike; + +#[derive(Debug, Clone, Copy)] +pub struct RangePair { + pub begin: u64, + pub end: u64, +} + +impl RangePair { + pub fn new(begin: u64, end: u64) -> RangePair { + RangePair { begin, end } + } + + // The length of a `RangePair` is the closed interval length + pub fn length(&self) -> u64 { + self.end - self.begin + 1 + } +} + +pub type RangeList = Vec; + +#[derive(Debug, Clone)] +pub struct SharedRangList { + inner: Arc>, +} + +impl SharedRangList { + pub fn new(rangelist: RangeList) -> SharedRangList { + SharedRangList { + inner: Arc::new(Mutex::new(rangelist)), + } + } +} + +impl StackLike for SharedRangList { + fn push(&mut self, pair: RangePair) { + self.inner.lock().unwrap().push(pair) + } + + fn pop(&mut self) -> Option { + self.inner.lock().unwrap().pop() + } + + fn len(&self) -> usize { + self.inner.lock().unwrap().len() + } +} + +/// Split a close `RangePair` to many piece of pairs that each of their size is equal to +/// `chunk_size`, but the last piece size can be less then `chunk_size`. +pub fn split_pair(pair: &RangePair, chunk_size: u64) -> RangeList { + let mut stack = Vec::new(); + + let mut begin = pair.begin; + let interval_end = pair.end; + + while begin + chunk_size - 1 <= interval_end { + let end = begin + chunk_size - 1; + stack.push(RangePair::new(begin, end)); + begin += chunk_size; + } + + if begin <= interval_end { + stack.push(RangePair::new(begin, interval_end)); + } + + stack +} diff --git a/src/common/size.rs b/src/common/size.rs new file mode 100644 index 0000000..a427286 --- /dev/null +++ b/src/common/size.rs @@ -0,0 +1,34 @@ +/// Here, we handle about size. 
+ +const SIZES: [&'static str; 5] = ["B", "K", "M", "G", "T"]; + +/// Convert number to human-readable +pub trait HumanReadable { + fn human_readable(&self) -> String; +} + +impl HumanReadable for u64 { + fn human_readable(&self) -> String { + let mut num = *self as f64; + for s in &SIZES { + if num < 1024.0 { + return format!("{:.1}{}", num, s); + } + num /= 1024.0; + } + return format!("{:.1}{}", num, SIZES[SIZES.len() - 1]); + } +} + +impl HumanReadable for f64 { + fn human_readable(&self) -> String { + let mut num = *self; + for s in &SIZES { + if num < 1024.0 { + return format!("{:.1}{}", num, s); + } + num /= 1024.0; + } + return format!("{:.1}{}", num, SIZES[SIZES.len() - 1]); + } +} diff --git a/src/common/tasks.rs b/src/common/tasks.rs new file mode 100644 index 0000000..33924d7 --- /dev/null +++ b/src/common/tasks.rs @@ -0,0 +1,5 @@ +#[derive(Debug, Clone)] +pub enum TaskType { + HTTP, + M3U8, +} diff --git a/src/common/terminal.rs b/src/common/terminal.rs new file mode 100644 index 0000000..753b174 --- /dev/null +++ b/src/common/terminal.rs @@ -0,0 +1,13 @@ +use term_size::dimensions; + +const MIN_TERMINAL_WIDTH: u64 = 60; + +pub fn terminal_width() -> u64 { + if let Some((width, _)) = dimensions() { + width as u64 + } else { + // for envrionment in which atty is not available, + // example, at ci of osx + MIN_TERMINAL_WIDTH + } +} diff --git a/src/common/uri.rs b/src/common/uri.rs new file mode 100644 index 0000000..28f5c4d --- /dev/null +++ b/src/common/uri.rs @@ -0,0 +1,22 @@ +use std::path::Path; + +use crate::common::{ + errors::{Error, Result}, + net::net_type::Uri, +}; + +/// Use the last of components of uri as a file name +pub trait UriFileName { + fn file_name(&self) -> Result<&str>; +} + +impl UriFileName for Uri { + fn file_name(&self) -> Result<&str> { + let path = Path::new(self.path()); + if let Some(file_name) = path.file_name() { + Ok(file_name.to_str().unwrap()) + } else { + Err(Error::NoFilename) + } + } +} diff --git a/src/core.rs 
b/src/core.rs deleted file mode 100644 index 0d8919a..0000000 --- a/src/core.rs +++ /dev/null @@ -1,314 +0,0 @@ -use std::{cell::RefCell, cmp::min, io::SeekFrom, rc::Rc, time::Duration}; - -use futures::{ - channel::mpsc::{channel, Receiver}, - select, - stream::StreamExt, -}; - -use actix_rt::{spawn, time::interval, System}; -use bytes::Bytes; - -use crate::{ - app::Config, - chunk::{make_range_chunks, RangePart, RangeStack}, - error::{AgetError, NetError, Result}, - printer::Printer, - request::{get_content_length, get_redirect_uri, AgetRequestOptions, ContentLengthItem}, - store::{AgetFile, File, TaskInfo}, - task::RequestTask, - util::QUIET, -}; - -pub struct CoreProcess { - config: Config, - options: AgetRequestOptions, - range_stack: RangeStack, - // the length of range_stack - range_count: u64, -} - -impl CoreProcess { - pub fn new(config: Config) -> Result { - let headers = config - .headers - .iter() - .map(AsRef::as_ref) - .collect::>(); - let data = config.data.as_ref().map(AsRef::as_ref); - let options = - AgetRequestOptions::new(&config.uri, &config.method, &headers, data, config.timeout)?; - - Ok(CoreProcess { - config, - options, - range_stack: Rc::new(RefCell::new(vec![])), - range_count: 1, - }) - } - - fn check_content_length(&self, content_length: u64) -> Result<()> { - debug!("Check content length", content_length); - let mut aget_file = AgetFile::new(&self.config.path)?; - if aget_file.exists() { - aget_file.open()?; - if content_length != aget_file.content_length()? { - debug!( - "!! the content length that response returned isn't equal of aget file", - format!("{} != {}", content_length, aget_file.content_length()?) 
- ); - return Err(AgetError::ContentLengthIsNotConsistent.into()); - } - } - debug!("Check content length: equal"); - - Ok(()) - } - - fn set_content_length(&self, content_length: u64) -> Result<()> { - debug!("Set content length", content_length); - let mut aget_file = AgetFile::new(&self.config.path)?; - if !aget_file.exists() { - aget_file.open()?; - aget_file.write_content_length(content_length)?; - } else { - aget_file.open()?; - aget_file.rewrite()?; - } - Ok(()) - } - - fn make_range_stack(&mut self) -> Result<()> { - debug!("Make range stack"); - - let mut range_stack: Vec = Vec::new(); - - if self.options.is_concurrent() { - let mut aget_file = AgetFile::new(&self.config.path)?; - aget_file.open()?; - let gaps = aget_file.gaps()?; - - let chunk_length = self.config.chunk_length; - for gap in gaps.iter() { - let mut list = make_range_chunks(gap, chunk_length); - range_stack.append(&mut list); - } - range_stack.reverse(); - } else { - range_stack.push(RangePart::new(0, 0)); - }; - - self.range_count = range_stack.len() as u64; - - debug!("Range stack size", range_stack.len()); - self.range_stack = Rc::new(RefCell::new(range_stack)); - - Ok(()) - } - - pub async fn run(&mut self) -> Result<()> { - // 1. Get redirected uri - debug!("Redirect task"); - get_redirect_uri(&mut self.options).await?; - - // 2. Get content length - debug!("ContentLength task"); - let cn_item = get_content_length(&mut self.options).await?; - match cn_item { - ContentLengthItem::RangeLength(content_length) => { - self.check_content_length(content_length)?; - self.set_content_length(content_length)?; - } - ContentLengthItem::DirectLength(content_length) => { - self.set_content_length(content_length)?; - self.options.no_concurrency(); - - // Let connector to be always alive - self.options.reset_connector(60, 60, 0); - } - ContentLengthItem::NoLength => { - return Err(NetError::NoContentLength.into()); - } - } - self.make_range_stack()?; - - // 3. 
Spawn concurrent tasks - let is_concurrent = self.options.is_concurrent(); - let (sender, receiver) = - channel::<(RangePart, Bytes)>((self.config.concurrency + 1) as usize); - let concurrency = if is_concurrent { - self.config.concurrency - } else { - 1 - }; - debug!("Spawn RequestTasks", concurrency); - - for i in 0..(min(self.config.concurrency, self.range_count)) { - debug!("RequestTask ", i); - let range_stack = self.range_stack.clone(); - let sender_ = sender.clone(); - let options = self.options.clone(); - let task = async { - let is_concurrent = options.is_concurrent(); - let mut request_task = RequestTask::new(range_stack, sender_); - let result = request_task.run(options).await; - if let Err(err) = result { - print_err!("RequestTask fails", err); - if !is_concurrent { - // Exit process when the only one request task fails - System::current().stop(); - } - } - }; - spawn(task); - } - - // 4. Wait stream handler - debug!("Start StreamHander"); - let concurrency = if is_concurrent { - self.config.concurrency - } else { - 0 - }; - let stream_header = StreamHander::new(&self.config.path, !is_concurrent, concurrency); - if stream_header.is_err() { - System::current().stop(); - } - let stream_header = stream_header.unwrap(); - stream_header.run(receiver).await; - - debug!("CoreProcess done"); - Ok(()) - } -} - -struct StreamHander { - file: File, - aget_file: AgetFile, - task_info: TaskInfo, - printer: Printer, - no_record: bool, - concurrency: u64, -} - -impl StreamHander { - fn new(path: &str, no_record: bool, concurrency: u64) -> Result { - let task_info = TaskInfo::new(path)?; - - let mut file = File::new(path, false)?; - file.open()?; - let mut aget_file = AgetFile::new(path)?; - aget_file.open()?; - - let printer = Printer::new(); - let mut handler = StreamHander { - file, - aget_file, - task_info, - printer, - no_record, - concurrency, - }; - handler.init_print()?; - Ok(handler) - } - - fn init_print(&mut self) -> Result<(), AgetError> { - unsafe { - if 
QUIET { - return Ok(()); - } - } - - if self.no_record { - self.printer - .print_msg("Server doesn't support range request.")?; - } - - let file_name = &self.task_info.path; - let content_length = self.task_info.content_length; - let concurrency = self.concurrency; - self.printer.print_header(file_name)?; - self.printer.print_length(content_length)?; - self.printer.print_concurrency(concurrency)?; - self.print_process()?; - Ok(()) - } - - fn print_process(&mut self) -> Result<(), AgetError> { - unsafe { - if QUIET { - return Ok(()); - } - } - - let total_length = self.task_info.content_length; - let completed_length = self.task_info.completed_length(); - let (rate, eta) = self.task_info.rate_and_eta(); - self.printer - .print_process(completed_length, total_length, rate, eta)?; - Ok(()) - } - - fn record_range(&mut self, range_part: RangePart) -> Result<(), ()> { - if self.no_record { - return Ok(()); - } - if let Err(err) = self.aget_file.write_interval(range_part) { - print_err!("write interval to aget file fails", err); - return Err(()); - } - return Ok(()); - } - - fn teardown(&mut self) -> Result<(), AgetError> { - self.aget_file.remove()?; - Ok(()) - } - - pub async fn run(mut self, receiver: Receiver<(RangePart, Bytes)>) { - debug!("StreamHander run"); - let mut tick = interval(Duration::from_secs(1)).fuse(); - let mut receiver = receiver.fuse(); - loop { - select! 
{ - item = receiver.next() => { - if let Some((range, chunk)) = item { - let interval_length = range.length(); - - // write buf - if let Err(err) = self - .file - .write(&chunk[..], Some(SeekFrom::Start(range.start))) - { - print_err!("write chunk to file fails", err); - } - - // write range_part - self.record_range(range); - - // update `task_info` - self.task_info.add_completed(interval_length); - } else { - break; - } - } - _ = tick.next() => { - if let Err(err) = self.print_process() { - print_err!("print process fails", err); - } - self.task_info.clean_interval(); - if self.task_info.remains() == 0 { - if let Err(err) = self.print_process() { - print_err!("print process fails", err); - } - if let Err(err) = self.teardown() { - print_err!("teardown stream handler fails", err); - } - break; - } - } - } - } - } -} diff --git a/src/error.rs b/src/error.rs deleted file mode 100644 index 072fc4a..0000000 --- a/src/error.rs +++ /dev/null @@ -1,182 +0,0 @@ -use std::{fmt, io::Error as IoError, num, result}; - -use failure::{self, Backtrace, Fail}; -use futures::channel::mpsc::SendError; - -use awc::{ - self, - error::SendRequestError, - http, - http::{ - header::{InvalidHeaderName, InvalidHeaderValue, ToStrError}, - uri::InvalidUri, - }, -}; - -pub type Result = result::Result; - -pub struct Error { - cause: Box, - backtrace: Option, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - fmt::Display::fmt(&self.cause, f) - } -} - -impl fmt::Debug for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if let Some(bt) = self.cause.backtrace() { - write!(f, "{:?}\n\n{:?}", &self.cause, bt) - } else { - write!( - f, - "{:?}\n\n{:?}", - &self.cause, - self.backtrace.as_ref().unwrap() - ) - } - } -} - -pub trait AgetFail: Fail {} - -impl From for Error { - fn from(err: T) -> Error { - let backtrace = if err.backtrace().is_none() { - Some(Backtrace::new()) - } else { - None - }; - Error { - cause: Box::new(err), - 
backtrace, - } - } -} - -#[derive(Fail, Debug)] -pub enum ArgError { - #[fail(display = "Output path is invalid: {}", _0)] - InvalidPath(String), - #[fail(display = "the uri is invalid: {}", _0)] - InvaildUri(String), - #[fail(display = "No filename.")] - NoFilename, - #[fail(display = "Directory is not found")] - NotFoundDirectory, - #[fail(display = "The file already exists.")] - FileExists, - #[fail(display = "The path is a directory.")] - PathIsDirectory, - #[fail(display = "Can't parse string as number: {}", _0)] - IsNotNumber(String), - #[fail(display = "Io Error: {}", _0)] - Io(#[cause] IoError), -} - -impl AgetFail for ArgError {} - -impl From for ArgError { - fn from(err: http::uri::InvalidUri) -> ArgError { - ArgError::InvaildUri(format!("{}", err)) - } -} - -impl From for ArgError { - fn from(err: num::ParseIntError) -> ArgError { - ArgError::IsNotNumber(format!("{}", err)) - } -} - -impl From for ArgError { - fn from(err: IoError) -> ArgError { - ArgError::Io(err) - } -} - -#[derive(Fail, Debug)] -pub enum AgetError { - #[fail(display = "Method is unsupported")] - UnsupportedMethod, - #[fail(display = "header is invalid: {}", _0)] - HeaderParseError(String), - #[fail(display = "BUG: {}", _0)] - Bug(String), - #[fail(display = "Path is invalid: {}", _0)] - InvalidPath(String), - #[fail(display = "Io Error: {}", _0)] - Io(#[cause] IoError), - #[fail(display = "No filename.")] - NoFilename, - #[fail(display = "The file already exists.")] - FileExists, - #[fail( - display = "The two content lengths are not equal between the response and the aget file." 
- )] - ContentLengthIsNotConsistent, -} - -impl AgetFail for AgetError {} - -impl From for AgetError { - fn from(err: IoError) -> AgetError { - AgetError::Io(err) - } -} - -#[derive(Fail, Debug)] -pub enum NetError { - #[fail(display = "an internal error: {}", _0)] - ActixError(String), - #[fail(display = "content does not has length")] - NoContentLength, - #[fail(display = "uri is invalid: {}", _0)] - InvaildUri(String), - #[fail(display = "header is invalid: {}", _0)] - InvaildHeader(String), - #[fail(display = "response status code is: {}", _0)] - Unsuccess(u16), - #[fail(display = "Redirect to: {}", _0)] - Redirect(String), -} - -impl AgetFail for NetError {} - -impl From for NetError { - fn from(err: SendRequestError) -> NetError { - NetError::ActixError(format!("{}", err)) - } -} - -impl From for NetError { - fn from(err: ToStrError) -> NetError { - NetError::ActixError(format!("{}", err)) - } -} - -impl From for NetError { - fn from(err: InvalidUri) -> NetError { - NetError::InvaildUri(format!("{}", err)) - } -} - -impl From for NetError { - fn from(err: InvalidHeaderName) -> NetError { - NetError::InvaildHeader(format!("{}", err)) - } -} - -impl From for NetError { - fn from(err: InvalidHeaderValue) -> NetError { - NetError::InvaildHeader(format!("{}", err)) - } -} - -impl From for NetError { - fn from(err: SendError) -> NetError { - NetError::ActixError(format!("{}", err)) - } -} diff --git a/src/features/args.rs b/src/features/args.rs new file mode 100644 index 0000000..9d1f689 --- /dev/null +++ b/src/features/args.rs @@ -0,0 +1,52 @@ +use std::path::PathBuf; + +use crate::common::{ + bytes::bytes_type::BytesMut, + net::net_type::{Method, Uri}, + tasks::TaskType, +}; + +/// This a arg which gives parameters for apps +pub trait Args { + /// Path of output + fn output(&self) -> PathBuf; + + /// Request method for http + fn method(&self) -> Method; + + /// The uri of a task + fn uri(&self) -> Uri; + + /// The data for http post request + fn data(&self) -> 
Option; + + /// Request headers + fn headers(&self) -> Vec<(String, String)>; + + /// Proxy: http, https, socks4, socks5 + fn proxy(&self) -> Option; + + /// The maximum time the request is allowed to take. + fn timeout(&self) -> u64; + + /// The number of concurrency + fn concurrency(&self) -> u64; + + /// The chunk length of each concurrency for http task + fn chunk_length(&self) -> u64; + + /// The number of retry of a task + fn retries(&self) -> u64; + + /// The internal of each retry + fn retry_wait(&self) -> u64; + + /// Task type + fn task_type(&self) -> TaskType; + + /// To debug mode, if it returns true + fn debug(&self) -> bool; + + /// To quiet mode, if it return true + fn quiet(&self) -> bool; +} diff --git a/src/features/mod.rs b/src/features/mod.rs new file mode 100644 index 0000000..2bed12f --- /dev/null +++ b/src/features/mod.rs @@ -0,0 +1,3 @@ +pub mod args; +pub mod running; +pub mod stack; diff --git a/src/features/running.rs b/src/features/running.rs new file mode 100644 index 0000000..f880277 --- /dev/null +++ b/src/features/running.rs @@ -0,0 +1,5 @@ +use crate::common::errors::Result; + +pub trait Runnable { + fn run(&mut self) -> Result<()>; +} diff --git a/src/features/stack.rs b/src/features/stack.rs new file mode 100644 index 0000000..c7acca7 --- /dev/null +++ b/src/features/stack.rs @@ -0,0 +1,7 @@ +pub trait StackLike { + fn push(&mut self, item: T); + + fn pop(&mut self) -> Option; + + fn len(&self) -> usize; +} diff --git a/src/main.rs b/src/main.rs index 3847294..8612fed 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,93 +1,68 @@ -#![allow(unused_variables)] #![allow(dead_code)] -#![recursion_limit = "256"] - -use std::{process::exit, thread, time}; - -use actix_rt::{spawn, System}; #[macro_use] -mod util; +mod common; mod app; -mod chunk; -mod clap_app; -mod common; -mod core; -mod error; -mod printer; -mod request; -mod store; -mod task; +mod arguments; +mod features; -use crate::{ - app::App, - core::CoreProcess, - util::{DEBUG, 
QUIET}, -}; +use std::{thread, time::Duration}; -static mut SUCCESS: bool = false; +use app::core::{http::HttpHandler, m3u8::m3u8::M3u8Handler}; +use arguments::cmd_args::CmdArgs; +use common::{ + debug::{DEBUG, QUIET}, + tasks::TaskType, +}; +use features::{args::Args, running::Runnable}; fn main() { - let app = App::new(); - match app.config() { - Ok(config) => { - // set verbose - unsafe { - DEBUG = config.debug; - QUIET = config.quiet; - } - - debug!("Input configuration", &config); + let cmdargs = CmdArgs::new(); - let retry_wait = config.retry_wait; - let max_retries = config.max_retries; - for i in 0..(max_retries + 1) { - if i > 0 { - print_err!("!!! Retry", i); - thread::sleep(time::Duration::from_secs(retry_wait)); - } + // Set debug + if cmdargs.debug() { + unsafe { + DEBUG = true; + } + debug!("Args", cmdargs); + } - let sys = System::new("Aget"); + // Set quiet + if cmdargs.quiet() { + unsafe { + QUIET = true; + } + } - debug!("Make CoreProcess task"); + debug!("Main: begin"); - let config_ = config.clone(); - spawn(async move { - let core_process = CoreProcess::new(config_); - if let Ok(mut core_fut) = core_process { - let result = core_fut.run().await; - if let Err(err) = result { - print_err!("core_fut fails", err); - } else { - unsafe { - SUCCESS = true; - } - } - } else { - print_err!("core_fut error", ""); - exit(1); - } - debug!("main done"); - System::current().stop(); - }); + let tasktype = cmdargs.task_type(); + for i in 0..cmdargs.retries() + 1 { + if i != 0 { + println!("Retry {}", i); + } - if let Err(err) = sys.run() { - print_err!("System error", err); - } else { - // check task state - unsafe { - if SUCCESS { - break; - } - } - } + let result = match tasktype { + TaskType::HTTP => { + let mut httphandler = HttpHandler::new(&cmdargs).unwrap(); + httphandler.run() } - // debug!("!!! 
Can't be here"); - } - Err(err) => { - print_err!("app config fails", err); - exit(1); + TaskType::M3U8 => { + let mut m3u8handler = M3u8Handler::new(&cmdargs).unwrap(); + m3u8handler.run() + } + }; + + if let Err(err) = result { + print_err!("Error", err); + // Retry + let retrywait = cmdargs.retry_wait(); + thread::sleep(Duration::from_secs(retrywait)); + continue; + } else { + // Success + break; } } } diff --git a/src/printer.rs b/src/printer.rs deleted file mode 100644 index 29cf39f..0000000 --- a/src/printer.rs +++ /dev/null @@ -1,156 +0,0 @@ -use std::io::{stdout, Stdout, Write}; - -use ansi_term::{ - Colour::{Blue, Cyan, Green, Red, Yellow}, - Style, -}; - -use crate::{ - error::{AgetError, Result}, - util::{terminal_width, SizeOfFmt, TimeOfFmt}, -}; - -pub struct Printer { - colors: Colors, - terminal_width: u64, - stdout: Stdout, -} - -impl Printer { - pub fn new() -> Printer { - let terminal_width = terminal_width(); - Printer { - colors: Colors::colored(), - terminal_width, - stdout: stdout(), - } - } - - pub fn print_msg(&mut self, msg: &str) -> Result<(), AgetError> { - writeln!(&mut self.stdout, "\n {}", self.colors.msg.paint(msg))?; - Ok(()) - } - - pub fn print_header(&mut self, path: &str) -> Result<(), AgetError> { - writeln!( - &mut self.stdout, - "\n {}: {}", - self.colors.file_header.paint(" File"), - path, - )?; - Ok(()) - } - - pub fn print_length(&mut self, content_length: u64) -> Result<(), AgetError> { - writeln!( - &mut self.stdout, - " {}: {} ({})", - self.colors.content_length_header.paint("Length"), - content_length.sizeof_fmt(), - content_length, - )?; - Ok(()) - } - - pub fn print_concurrency(&mut self, concurrency: u64) -> Result<(), AgetError> { - writeln!( - &mut self.stdout, - "{}: {}\n", - self.colors.concurrency_header.paint("concurrency"), - concurrency, - )?; - Ok(()) - } - - pub fn print_process( - &mut self, - completed_length: u64, - total_length: u64, - rate: f64, - eta: u64, - ) -> Result<(), AgetError> { - let percent 
= completed_length as f64 / total_length as f64; - - let completed_length_str = completed_length.sizeof_fmt(); - let total_length_str = total_length.sizeof_fmt(); - let percent_str = format!("{:.2}", percent * 100.0); - let rate_str = rate.sizeof_fmt(); - let eta_str = eta.timeof_fmt(); - - // maximum info length is 41 e.g. - // 1001.3k/1021.9m 97.98% 1003.1B/s eta: 12s - let info = format!( - "{completed_length}/{total_length} {percent}% {rate}/s eta: {eta}", - completed_length = completed_length_str, - total_length = total_length_str, - percent = percent_str, - rate = rate_str, - eta = eta_str, - ); - - // set default info length - let info_length = 41; - let miss = info_length - info.len(); - - let bar_length = self.terminal_width - info_length as u64 - 4; - let process_bar_length = (bar_length as f64 * percent) as u64; - let blank_length = bar_length - process_bar_length; - - let process_bar_str = if process_bar_length > 0 { - format!("{}>", "=".repeat((process_bar_length - 1) as usize)) - } else { - "".to_owned() - }; - let blank_str = " ".repeat(blank_length as usize); - - write!( - &mut self.stdout, - "\r{completed_length}/{total_length} {percent}% {rate}/s eta: {eta}{miss} [{process_bar}{blank}] ", - completed_length = self.colors.completed_length.paint(completed_length_str), - total_length = self.colors.total_length.paint(total_length_str), - percent = self.colors.percent.paint(percent_str), - rate = self.colors.rate.paint(rate_str), - eta = self.colors.eta.paint(eta_str), - miss = " ".repeat(miss), - process_bar = process_bar_str, - blank = blank_str, - )?; - - self.stdout.flush()?; - - Ok(()) - } -} - -#[derive(Default)] -pub struct Colors { - pub file_header: Style, - pub content_length_header: Style, - pub concurrency_header: Style, - pub completed_length: Style, - pub total_length: Style, - pub percent: Style, - pub rate: Style, - pub eta: Style, - pub msg: Style, -} - -impl Colors { - pub fn plain() -> Colors { - Colors::default() - } - - pub fn 
colored() -> Colors { - Colors { - file_header: Green.bold(), - content_length_header: Blue.bold(), - completed_length: Red.bold(), - concurrency_header: Yellow.bold(), - total_length: Green.bold(), - percent: Yellow.bold(), - rate: Blue.bold(), - eta: Cyan.bold(), - msg: Yellow.italic(), - } - } -} diff --git a/src/request.rs b/src/request.rs deleted file mode 100644 index e9d70d4..0000000 --- a/src/request.rs +++ /dev/null @@ -1,198 +0,0 @@ -use std::time::Duration; - -use awc::{ - http::{header, Method, Uri}, - Client, ClientBuilder, ClientRequest, Connector, -}; - -use clap::crate_version; - -use crate::error::{AgetError, NetError, Result}; - -fn parse_header(raw: &str) -> Result<(&str, &str), AgetError> { - if let Some(index) = raw.find(": ") { - return Ok((&raw[..index], &raw[index + 2..])); - } - if let Some(index) = raw.find(":") { - return Ok((&raw[..index], &raw[index + 1..])); - } - Err(AgetError::HeaderParseError(raw.to_string())) -} - -#[derive(Clone)] -pub struct AgetRequestOptions { - uri: String, - method: Method, - headers: Vec<(String, String)>, - body: Option, - concurrent: bool, - client: Client, -} - -impl AgetRequestOptions { - pub fn new( - uri: &str, - method: &str, - headers: &[&str], - body: Option<&str>, - timeout: u64, - ) -> Result { - let method = match method.to_uppercase().as_str() { - "GET" => Method::GET, - "POST" => Method::POST, - _ => return Err(AgetError::UnsupportedMethod), - }; - - let mut header_list = Vec::new(); - for header in headers.iter() { - let (key, value) = parse_header(header)?; - header_list.push((key.to_string(), value.to_string())); - } - - let connector = Connector::new() - .limit(0) // no limit simultaneous connections. 
- .timeout(Duration::from_secs(timeout)) // DNS timeout - .conn_keep_alive(Duration::from_secs(60)) - .conn_lifetime(Duration::from_secs(0)) - .finish(); - let client = ClientBuilder::new().connector(connector).finish(); - - Ok(AgetRequestOptions { - method, - uri: uri.to_string(), - headers: header_list, - body: if let Some(body) = body { - Some(body.to_string()) - } else { - None - }, - concurrent: true, - client, - }) - } - - pub fn build(&self) -> Result { - // set user-agent if none - let aget_ua = format!("aget/{}", crate_version!()); - - let uri = self.uri.parse::()?; - let host = if let Some(host) = uri.host() { - host - } else { - return Err(NetError::InvaildUri(self.uri.to_string())); - }; - - let mut client_request = self - .client - .request(self.method.clone(), self.uri.clone()) - .set_header_if_none("User-Agent", aget_ua) // set default user-agent - .set_header_if_none("Accept", "*/*") // set accept if none - .set_header_if_none("Host", host); // set header `Host` - - for (ref key, ref val) in &self.headers { - client_request - .headers_mut() - .insert(key.as_str().parse()?, val.as_str().parse()?); - } - - Ok(client_request) - } - - pub fn uri(&self) -> String { - self.uri.clone() - } - - pub fn set_uri(&mut self, uri: &str) -> &mut Self { - self.uri = uri.to_string(); - self - } - - pub fn body(&self) -> &Option { - &self.body - } - - pub fn is_concurrent(&self) -> bool { - self.concurrent - } - - pub fn no_concurrency(&mut self) -> &mut Self { - self.concurrent = false; - self - } - - pub fn reset_connector(&mut self, timeout: u64, keep_alive: u64, lifetime: u64) -> &mut Self { - let connector = Connector::new() - .limit(0) // no limit simultaneous connections. 
- .timeout(Duration::from_secs(timeout)) // DNS timeout - .conn_keep_alive(Duration::from_secs(keep_alive)) - .conn_lifetime(Duration::from_secs(lifetime)) - .finish(); - self.client = ClientBuilder::new().connector(connector).finish(); - self - } -} - -/// Get redirected uri and reset `AgetRequestOptions.uri` -pub async fn get_redirect_uri(options: &mut AgetRequestOptions) -> Result<(), NetError> { - loop { - let client_request = options.build()?; - let resp = if let Some(body) = options.body() { - client_request.send_body(body).await? - } else { - client_request.send().await? - }; - let status = resp.status(); - if !(status.is_success() || status.is_redirection()) { - return Err(NetError::Unsuccess(status.as_u16())); - } else { - if status.is_redirection() { - if let Some(location) = resp.headers().get(header::LOCATION) { - options.set_uri(location.to_str()?); - } - } - return Ok(()); - } - } -} - -#[derive(Debug)] -pub enum ContentLengthItem { - RangeLength(u64), - DirectLength(u64), - NoLength, -} - -pub async fn get_content_length( - options: &mut AgetRequestOptions, -) -> Result { - let client_request = options.build()?.header(header::RANGE, "bytes=0-1"); - let resp = if let Some(body) = options.body() { - client_request.send_body(body).await? - } else { - client_request.send().await? 
- }; - - let status = resp.status(); - if !status.is_success() { - return Err(NetError::Unsuccess(status.as_u16())); - } else { - if let Some(h) = resp.headers().get(header::CONTENT_RANGE) { - if let Ok(s) = h.to_str() { - if let Some(index) = s.find("/") { - if let Ok(length) = &s[index + 1..].parse::() { - return Ok(ContentLengthItem::RangeLength(length.clone())); - } - } - } - } else { - if let Some(h) = resp.headers().get(header::CONTENT_LENGTH) { - if let Ok(s) = h.to_str() { - if let Ok(length) = s.parse::() { - return Ok(ContentLengthItem::DirectLength(length.clone())); - } - } - } - } - } - Ok(ContentLengthItem::NoLength) -} diff --git a/src/store.rs b/src/store.rs deleted file mode 100644 index e19c24c..0000000 --- a/src/store.rs +++ /dev/null @@ -1,362 +0,0 @@ -use std::{ - cmp::max, - fs::{remove_file, File as StdFile, OpenOptions}, - io::{Read, Seek, SeekFrom, Write}, - path::PathBuf, - time::Instant, -}; - -use crate::{ - chunk::RangePart, - common::AGET_EXT, - error::{AgetError, Result}, -}; - -pub struct TaskInfo { - pub path: String, - - /// The length of the file - pub content_length: u64, - - /// The length stored to the file - completed_length: u64, - - /// The stored length at an interval of one tick - interval_length: u64, - - /// The interval of one tick - tick_interval: Instant, -} - -impl TaskInfo { - pub fn new(path: &str) -> Result { - let mut aget_file = AgetFile::new(path)?; - aget_file.open()?; - - let path = path.to_string(); - let content_length = aget_file.content_length()?; - let completed_length = aget_file.completed_length()?; - - Ok(TaskInfo { - path, - content_length, - completed_length, - interval_length: 0, - tick_interval: Instant::now(), - }) - } - - pub fn completed_length(&self) -> u64 { - self.completed_length - } - - pub fn remains(&self) -> u64 { - self.content_length - self.completed_length - } - - pub fn rate_and_eta(&self) -> (f64, u64) { - let interval = self.tick_interval.elapsed().as_secs_f64(); - let rate = 
self.interval_length as f64 / interval; - let remains = self.remains(); - // rate > 1.0 for overflow - let eta = if remains > 0 && rate > 1.0 { - let eta = (remains as f64 / rate) as u64; - // eta is large than 99 days, return 0 - if eta > 99 * 24 * 60 * 60 { - 0 - } else { - eta - } - } else { - 0 - }; - (rate, eta) - } - - pub fn add_completed(&mut self, interval_length: u64) { - self.completed_length += interval_length; - self.interval_length += interval_length; - } - - pub fn clean_interval(&mut self) { - self.interval_length = 0; - self.tick_interval = Instant::now(); - } -} - -pub struct File { - path: PathBuf, - file: Option, - readable: bool, -} - -impl File { - pub fn new(p: &str, readable: bool) -> Result { - let path = PathBuf::from(p); - if path.is_dir() { - return Err(AgetError::InvalidPath(p.to_string())); - } - - Ok(File { - path, - file: None, - readable, - }) - } - - pub fn open(&mut self) -> Result<&mut Self, AgetError> { - let file = OpenOptions::new() - .read(self.readable) - .write(true) - .truncate(false) - .create(true) - .open(self.path.as_path())?; - self.file = Some(file); - Ok(self) - } - - pub fn exists(&self) -> bool { - self.path.as_path().exists() - } - - pub fn file_name(&self) -> Result { - if let Some(file_name) = self.path.as_path().file_name() { - Ok(file_name.to_str().unwrap().to_string()) - } else { - Err(AgetError::NoFilename) - } - } - - pub fn file(&mut self) -> Result<&mut StdFile, AgetError> { - if let Some(ref mut file) = self.file { - Ok(file) - } else { - Err(AgetError::Bug( - "`store::File::file` must be opened".to_string(), - )) - } - } - - pub fn write(&mut self, buf: &[u8], seek: Option) -> Result { - if let Some(seek) = seek { - self.seek(seek)?; - } - let s = self.file()?.write(buf)?; - Ok(s) - } - - pub fn read(&mut self, buf: &mut [u8], seek: Option) -> Result { - if let Some(seek) = seek { - self.seek(seek)?; - } - let s = self.file()?.read(buf)?; - Ok(s) - } - - pub fn seek(&mut self, seek: SeekFrom) -> Result 
{ - let s = self.file()?.seek(seek)?; - Ok(s) - } - - pub fn set_len(&mut self, size: u64) -> Result<(), AgetError> { - self.file()?.set_len(size)?; - Ok(()) - } - - pub fn remove(&self) -> Result<(), AgetError> { - remove_file(self.path.as_path())?; - Ok(()) - } -} - -/// Aget Information Store File -/// -/// The file stores two kinds of information which of the downloading file. -/// 1. Content Length -/// The downloading file's content length. -/// If the number `request::ContentLength` returned is not equal to this content length, -/// the process will be terminated. -/// 2. Close Intervals of Downloaded Pieces -/// These intervals are pieces `(u64, u64)` stored as big-endian. -/// First item is the begin of header `Range`. -/// Second item is the end of header `Range`. -pub struct AgetFile { - inner: File, -} - -impl AgetFile { - pub fn new(path: &str) -> Result { - let mut path = path.to_string(); - path.push_str(AGET_EXT); - - let file = File::new(&path, true)?; - Ok(AgetFile { inner: file }) - } - - pub fn open(&mut self) -> Result<&mut Self, AgetError> { - self.inner.open()?; - Ok(self) - } - - pub fn file_name(&self) -> Result { - self.inner.file_name() - } - - pub fn exists(&self) -> bool { - self.inner.exists() - } - - pub fn remove(&self) -> Result<(), AgetError> { - self.inner.remove() - } - - /// Get downloading file's content length stored in the aget file - pub fn content_length(&mut self) -> Result { - let mut buf: [u8; 8] = [0; 8]; - self.inner.read(&mut buf, Some(SeekFrom::Start(0)))?; - let content_length = u8x8_to_u64(&buf); - Ok(content_length) - } - - pub fn completed_length(&mut self) -> Result { - let completed_intervals = self.completed_intervals()?; - if completed_intervals.is_empty() { - Ok(0) - } else { - Ok(completed_intervals.iter().map(RangePart::length).sum()) - } - } - - pub fn completed_intervals(&mut self) -> Result, AgetError> { - let mut intervals: Vec<(u64, u64)> = Vec::new(); - - let mut buf: [u8; 16] = [0; 16]; - 
self.inner.seek(SeekFrom::Start(8))?; - loop { - let s = self.inner.read(&mut buf, None)?; - if s != 16 { - break; - } - - let mut raw = [0; 8]; - raw.clone_from_slice(&buf[..8]); - let start = u8x8_to_u64(&raw); - raw.clone_from_slice(&buf[8..]); - let end = u8x8_to_u64(&raw); - - assert!( - start <= end, - format!( - "Bug: `start > end` in an interval of aget file. : {} > {}", - start, end - ) - ); - - intervals.push((start, end)); - } - - intervals.sort(); - - // merge intervals - let mut merge_intervals: Vec<(u64, u64)> = Vec::new(); - if !intervals.is_empty() { - merge_intervals.push(intervals[0]); - } - for (start, end) in intervals.iter() { - let (pre_start, pre_end) = merge_intervals.last().unwrap().clone(); - - // case 1 - // ---------- - // ----------- - if pre_end + 1 < *start { - merge_intervals.push((*start, *end)); - // case 2 - // ----------------- - // ---------- - // -------- - // ------ - } else { - let n_start = pre_start; - let n_end = max(pre_end, *end); - merge_intervals.pop(); - merge_intervals.push((n_start, n_end)); - } - } - - Ok(merge_intervals - .iter() - .map(|(start, end)| RangePart::new(*start, *end)) - .collect::>()) - } - - /// Get gaps of undownload pieces - pub fn gaps(&mut self) -> Result, AgetError> { - let mut completed_intervals = self.completed_intervals()?; - let content_length = self.content_length()?; - completed_intervals.push(RangePart::new(content_length, content_length)); - - // find gaps - let mut gaps: Vec = Vec::new(); - // find first chunk - let RangePart { start, .. } = completed_intervals[0]; - if start > 0 { - gaps.push(RangePart::new(0, start - 1)); - } - - for (index, RangePart { end, .. }) in completed_intervals.iter().enumerate() { - if let Some(RangePart { - start: next_start, .. 
- }) = completed_intervals.get(index + 1) - { - if end + 1 < *next_start { - gaps.push(RangePart::new(end + 1, next_start - 1)); - } - } - } - - Ok(gaps) - } - - pub fn write_content_length(&mut self, content_length: u64) -> Result<(), AgetError> { - let buf = u64_to_u8x8(content_length); - self.inner.write(&buf, Some(SeekFrom::Start(0)))?; - Ok(()) - } - - pub fn write_interval(&mut self, interval: RangePart) -> Result<(), AgetError> { - let start = u64_to_u8x8(interval.start); - let end = u64_to_u8x8(interval.end); - let buf = [start, end].concat(); - self.inner.write(&buf, Some(SeekFrom::End(0)))?; - Ok(()) - } - - // Merge completed intervals and rewrite the aget file - pub fn rewrite(&mut self) -> Result<(), AgetError> { - let content_length = self.content_length()?; - let completed_intervals = self.completed_intervals()?; - - let mut buf: Vec = Vec::new(); - buf.extend(&u64_to_u8x8(content_length)); - for interval in completed_intervals.iter() { - buf.extend(&u64_to_u8x8(interval.start)); - buf.extend(&u64_to_u8x8(interval.end)); - } - - self.inner.set_len(0)?; - self.inner.write(buf.as_slice(), Some(SeekFrom::Start(0)))?; - - Ok(()) - } -} - -/// Create an integer value from its representation as a byte array in big endian. -pub fn u8x8_to_u64(u8x8: &[u8; 8]) -> u64 { - u64::from_be_bytes(*u8x8) -} - -/// Return the memory representation of this integer as a byte array in big-endian (network) byte -/// order. 
-pub fn u64_to_u8x8(u: u64) -> [u8; 8] { - u.to_be_bytes() -} diff --git a/src/task.rs b/src/task.rs deleted file mode 100644 index a517aa4..0000000 --- a/src/task.rs +++ /dev/null @@ -1,103 +0,0 @@ -use std::time::Duration; - -use awc::http::header; - -use futures::{channel::mpsc::Sender, stream::StreamExt}; -use futures_util::sink::SinkExt; - -use bytes::Bytes; - -use crate::{ - chunk::{RangePart, RangeStack}, - error::NetError, - request::AgetRequestOptions, -}; - -pub struct RequestTask { - range_stack: RangeStack, - sender: Sender<(RangePart, Bytes)>, -} - -impl RequestTask { - pub fn new(range_stack: RangeStack, sender: Sender<(RangePart, Bytes)>) -> RequestTask { - RequestTask { - range_stack, - sender, - } - } - - fn pop_range(&mut self) -> Option { - let mut stack = self.range_stack.borrow_mut(); - (*stack).pop() - } - - fn push_range(&mut self, range: RangePart) { - let mut stack = self.range_stack.borrow_mut(); - (*stack).push(range); - } - - pub async fn run(&mut self, mut options: AgetRequestOptions) -> Result<(), NetError> { - while let Some(range) = self.pop_range() { - let timeout = if options.is_concurrent() { - Duration::from_secs(60) - } else { - Duration::from_secs(10 * 24 * 60 * 60) // 10 days - }; - - let mut client_request = options.build()?.timeout(timeout); - if options.is_concurrent() { - client_request.headers_mut().insert( - header::RANGE, - format!("bytes={}-{}", range.start, range.end).parse()?, - ); - } - let resp = if let Some(body) = options.body() { - client_request.send_body(body).await - } else { - client_request.send().await - }; - - if resp.is_err() { - self.push_range(range); - continue; - } - - let mut resp = resp.unwrap(); - - let status = resp.status(); - - // handle redirect - if status.is_redirection() { - if let Some(location) = resp.headers().get(header::LOCATION) { - options.set_uri(location.to_str()?); - } - self.push_range(range); - continue; - } - - if !status.is_success() { - debug!("request error", status); - 
self.push_range(range); - continue; - } - - let mut range = range.clone(); - while let Some(chunk) = resp.next().await { - if let Ok(chunk) = chunk { - let len = chunk.len() as u64; - let start = range.start; - let end = start + len; - range.start = end; - - let mut sender = self.sender.clone(); - // the sended RangePart is a close interval as header `Range` - sender.send((RangePart::new(start, end - 1), chunk)).await?; - } else { - self.push_range(range); - break; - } - } - } - Ok(()) - } -} diff --git a/src/util.rs b/src/util.rs deleted file mode 100644 index 6333ed0..0000000 --- a/src/util.rs +++ /dev/null @@ -1,154 +0,0 @@ -use std::path::Path; - -use awc::http::Uri; - -use term_size::dimensions; - -use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS}; - -use crate::error::{AgetError, ArgError, Result}; - -pub trait FindFilename { - fn find_file_name(&self) -> Result<&str, AgetError>; -} - -impl FindFilename for Uri { - fn find_file_name(&self) -> Result<&str, AgetError> { - let path = Path::new(self.path()); - if let Some(file_name) = path.file_name() { - Ok(file_name.to_str().unwrap()) - } else { - Err(AgetError::NoFilename) - } - } -} - -const MIN_TERMINAL_WIDTH: u64 = 60; - -pub fn terminal_width() -> u64 { - if let Some((width, _)) = dimensions() { - width as u64 - } else { - // for envrionment in which atty is not available, - // example, at ci of osx - MIN_TERMINAL_WIDTH - } -} - -const SIZES: [&'static str; 4] = ["B", "K", "M", "G"]; - -pub trait SizeOfFmt { - fn sizeof_fmt(&self) -> String; -} - -impl SizeOfFmt for u64 { - fn sizeof_fmt(&self) -> String { - let mut num = *self as f64; - for s in SIZES.iter() { - if num < 1024.0 { - return format!("{:.1}{}", num, s); - } - num /= 1024.0; - } - return format!("{:.1}{}", num, "G"); - } -} - -impl SizeOfFmt for f64 { - fn sizeof_fmt(&self) -> String { - let mut num = *self; - for s in SIZES.iter() { - if num < 1024.0 { - return format!("{:.1}{}", num, s); - } - num /= 1024.0; - } - return 
format!("{:.1}{}", num, "G"); - } -} - -pub trait LiteralSize { - fn literal_size(&self) -> Result; -} - -impl LiteralSize for &str { - fn literal_size(&self) -> Result { - let (num, unit) = self.split_at(self.len() - 1); - if unit.parse::().is_err() { - let mut num = num.parse::()?; - for s in SIZES.iter() { - if s.to_lowercase() == unit.to_lowercase() { - return Ok(num); - } else { - num *= 1024; - } - } - Ok(num) - } else { - let num = self.parse::()?; - Ok(num) - } - } -} - -pub trait TimeOfFmt { - fn timeof_fmt(&self) -> String; -} - -impl TimeOfFmt for u64 { - fn timeof_fmt(&self) -> String { - let mut num = *self as f64; - if num < 60.0 { - return format!("{:.0}s", num); - } - num /= 60.0; - if num < 60.0 { - return format!("{:.0}m", num); - } - num /= 60.0; - if num < 24.0 { - return format!("{:.0}h", num); - } - num /= 24.0; - return format!("{:.0}d", num); - } -} - -pub static mut DEBUG: bool = false; -pub static mut QUIET: bool = false; - -#[macro_export] -macro_rules! print_err { - ( $ctx:expr, $err:expr ) => { - eprintln!("[{}:{}] {}: {}", file!(), line!(), $ctx, $err); - }; -} - -#[macro_export] -macro_rules! debug { - ( $title:expr, $msg:expr ) => { - unsafe { - if crate::util::DEBUG { - eprintln!("[{}:{}] {}: {:#?}", file!(), line!(), $title, $msg); - } - } - }; - ( $title:expr ) => { - unsafe { - if crate::util::DEBUG { - eprintln!("[{}:{}] {}", file!(), line!(), $title); - } - } - }; -} - -const FRAGMENT: &AsciiSet = &CONTROLS - .add(b' ') - .remove(b'?') - .remove(b'/') - .remove(b':') - .remove(b'='); - -pub fn escape_nonascii(target: &str) -> String { - utf8_percent_encode(target, FRAGMENT).to_string() -}