Skip to content

Commit

Permalink
add oci blob download support
Browse files Browse the repository at this point in the history
  • Loading branch information
QaidVoid committed Jan 18, 2025
1 parent 8f4ec9d commit 45de669
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 44 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "soar-dl"
version = "0.2.0"
version = "0.3.0"
authors = ["Rabindra Dhakal <[email protected]>"]
description = "A fast download manager"
license = "MIT"
Expand Down
73 changes: 68 additions & 5 deletions src/downloader.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::{
collections::HashMap,
fs::Permissions,
os::unix::fs::PermissionsExt,
path::Path,
path::{Path, PathBuf},
sync::{Arc, Mutex},
};

Expand All @@ -16,7 +17,7 @@ use url::Url;

use crate::{
error::DownloadError,
oci::{OciClient, Reference},
oci::{OciClient, OciLayer, Reference},
utils::{extract_filename, is_elf},
};

Expand Down Expand Up @@ -120,10 +121,63 @@ impl Downloader {
Ok(filename)
}

pub async fn download_blob(
&self,
client: OciClient,
options: DownloadOptions,
) -> Result<(), DownloadError> {
let reference = client.reference.clone();
let digest = reference.tag;
let downloaded_bytes = Arc::new(Mutex::new(0u64));
let output_path = options.output_path;
let ref_name = reference
.package
.rsplit_once('/')
.map_or(digest.clone(), |(_, name)| name.to_string());
let file_path = output_path.unwrap_or_else(|| ref_name.clone());
let file_path = if file_path.ends_with('/') {
fs::create_dir_all(&file_path).await?;
format!("{}/{}", file_path.trim_end_matches('/'), ref_name)
} else {
file_path
};

let fake_layer = OciLayer {
media_type: String::from("application/octet-stream"),
digest: digest.clone(),
size: 0,
annotations: HashMap::new(),
};

let cb_clone = options.progress_callback.clone();
client
.pull_layer(&fake_layer, &file_path, move |bytes, total_bytes| {
if let Some(ref callback) = cb_clone {
if total_bytes > 0 {
callback(DownloadState::Preparing(total_bytes));
}
let mut current = downloaded_bytes.lock().unwrap();
*current = bytes;
callback(DownloadState::Progress(*current));
}
})
.await?;

if let Some(ref callback) = options.progress_callback {
callback(DownloadState::Complete);
}

Ok(())
}

pub async fn download_oci(&self, options: DownloadOptions) -> Result<(), DownloadError> {
let url = options.url.clone();
let reference: Reference = url.into();
let oci_client = OciClient::new(reference);
let oci_client = OciClient::new(&reference);

if reference.tag.starts_with("sha256:") {
return self.download_blob(oci_client, options).await;
}

let manifest = oci_client.manifest().await.unwrap();

Expand All @@ -136,16 +190,25 @@ impl Downloader {

let downloaded_bytes = Arc::new(Mutex::new(0u64));
let outdir = options.output_path;
let base_path = if let Some(dir) = outdir {
fs::create_dir_all(&dir).await?;
PathBuf::from(dir)
} else {
PathBuf::new()
};

for layer in manifest.layers {
let client_clone = oci_client.clone();
let cb_clone = options.progress_callback.clone();
let downloaded_bytes = downloaded_bytes.clone();
let outdir = outdir.clone();
let Some(filename) = layer.get_title() else {
continue;
};
let file_path = base_path.join(filename);

let task = task::spawn(async move {
client_clone
.pull_layer(&layer, outdir, move |bytes| {
.pull_layer(&layer, &file_path, move |bytes, _| {
if let Some(ref callback) = cb_clone {
let mut current = downloaded_bytes.lock().unwrap();
*current = bytes;
Expand Down
89 changes: 52 additions & 37 deletions src/oci.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::{
collections::HashMap,
fs::Permissions,
os::unix::fs::PermissionsExt,
path::{Path, PathBuf},
};

Expand All @@ -11,7 +13,7 @@ use tokio::{
io::AsyncWriteExt,
};

use crate::error::DownloadError;
use crate::{error::DownloadError, utils::is_elf};

#[derive(Deserialize)]
pub struct OciLayer {
Expand Down Expand Up @@ -41,43 +43,55 @@ pub struct OciManifest {
#[derive(Clone)]
pub struct OciClient {
client: reqwest::Client,
reference: Reference,
pub reference: Reference,
}

#[derive(Clone)]
pub struct Reference {
package: String,
tag: String,
pub package: String,
pub tag: String,
}

impl From<&str> for Reference {
fn from(value: &str) -> Self {
let paths = value.trim_start_matches("ghcr.io/");
let (package, tag) = paths.split_once(':').unwrap_or((paths, "latest"));

// <package>@sha256:<digest>
if let Some((package, digest)) = paths.split_once("@") {
return Self {
package: package.to_string(),
tag: digest.to_string(),
};
}

// <package>:<tag>
if let Some((package, tag)) = paths.split_once(':') {
return Self {
package: package.to_string(),
tag: tag.to_string(),
};
}

Self {
package: package.to_string(),
tag: tag.to_string(),
package: paths.to_string(),
tag: "latest".to_string(),
}
}
}

impl From<String> for Reference {
fn from(value: String) -> Self {
let paths = value.trim_start_matches("ghcr.io/");
let (package, tag) = paths.split_once(':').unwrap_or((paths, "latest"));

Self {
package: package.to_string(),
tag: tag.to_string(),
}
value.as_str().into()
}
}

impl OciClient {
pub fn new(reference: Reference) -> Self {
pub fn new(reference: &Reference) -> Self {
let client = reqwest::Client::new();
Self { client, reference }
Self {
client,
reference: reference.clone(),
}
}

pub fn headers(&self) -> HeaderMap {
Expand Down Expand Up @@ -114,14 +128,15 @@ impl OciClient {
Ok(manifest)
}

pub async fn pull_layer<F, P: AsRef<Path>>(
pub async fn pull_layer<F, P>(
&self,
layer: &OciLayer,
output_dir: Option<P>,
output_path: P,
progress_callback: F,
) -> Result<u64, DownloadError>
where
F: Fn(u64) + Send + 'static,
P: AsRef<Path>,
F: Fn(u64, u64) + Send + 'static,
{
let blob_url = format!(
"https://ghcr.io/v2/{}/blobs/{}",
Expand All @@ -142,22 +157,11 @@ impl OciClient {
});
}

let Some(filename) = layer.get_title() else {
// skip if layer doesn't contain title
return Ok(0);
};

let (temp_path, final_path) = if let Some(output_dir) = output_dir {
let output_dir = output_dir.as_ref();
fs::create_dir_all(output_dir).await?;
let final_path = output_dir.join(format!("{filename}"));
let temp_path = output_dir.join(format!("{filename}.part"));
(temp_path, final_path)
} else {
let final_path = PathBuf::from(&filename);
let temp_path = PathBuf::from(format!("{filename}.part"));
(temp_path, final_path)
};
let content_length = resp.content_length().unwrap_or(0);
progress_callback(0, content_length);

let output_path = output_path.as_ref();
let temp_path = PathBuf::from(&format!("{}.part", output_path.display()));

let mut file = OpenOptions::new()
.create(true)
Expand All @@ -173,11 +177,15 @@ impl OciClient {
let chunk_size = chunk.len() as u64;
file.write_all(&chunk).await.unwrap();

progress_callback(chunk_size);
progress_callback(chunk_size, 0);
total_bytes_downloaded += chunk_size;
}

fs::rename(&temp_path, &final_path).await?;
fs::rename(&temp_path, &output_path).await?;

if is_elf(&output_path).await {
fs::set_permissions(&output_path, Permissions::from_mode(0o755)).await?;
}

Ok(total_bytes_downloaded)
}
Expand All @@ -189,4 +197,11 @@ impl OciLayer {
.get("org.opencontainers.image.title")
.cloned()
}

pub fn set_title(&mut self, title: &str) {
self.annotations.insert(
"org.opencontainers.image.title".to_string(),
title.to_string(),
);
}
}

0 comments on commit 45de669

Please sign in to comment.