Skip to content

Commit

Permalink
Merge branch 'main' into ontake-dev
Browse files Browse the repository at this point in the history
  • Loading branch information
petarvujovic98 authored Aug 19, 2024
2 parents a4364bb + 34c1348 commit dd8b440
Show file tree
Hide file tree
Showing 11 changed files with 273 additions and 24 deletions.
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ risc0-build = { version = "1.0.1" }
risc0-binfmt = { version = "1.0.1" }

# SP1
sp1-sdk = { version = "1.0.1" }
sp1-zkvm = { version = "1.0.1" }
sp1-helper = { version = "1.0.1" }
sp1-sdk = { version = "1.0.1" }
sp1-zkvm = { version = "1.0.1" }
sp1-helper = { version = "1.0.1" }

# alloy
alloy-rlp = { version = "0.3.4", default-features = false }
Expand Down Expand Up @@ -191,4 +191,4 @@ revm-primitives = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-
revm-precompile = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-taiko" }
secp256k1 = { git = "https://github.com/CeciliaZ030/rust-secp256k1", branch = "sp1-patch" }
blst = { git = "https://github.com/CeciliaZ030/blst.git", branch = "v0.3.12-serialize" }
alloy-serde = { git = "https://github.com/CeciliaZ030/alloy.git", branch = "v0.1.4-fix"}
alloy-serde = { git = "https://github.com/CeciliaZ030/alloy.git", branch = "v0.1.4-fix" }
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ ARG BUILD_FLAGS=""

WORKDIR /opt/raiko
COPY . .
RUN cargo build --release ${BUILD_FLAGS} --features "sgx" --features "docker_build"
RUN cargo build --release ${BUILD_FLAGS} --features "sgx,risc0" --features "docker_build"

FROM gramineproject/gramine:1.6-jammy as runtime
ENV DEBIAN_FRONTEND=noninteractive
Expand Down
8 changes: 7 additions & 1 deletion provers/risc0/driver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ typetag = { workspace = true, optional = true }
serde_with = { workspace = true, optional = true }
serde_json = { workspace = true, optional = true }
hex = { workspace = true, optional = true }
reqwest = { workspace = true, optional = true }
lazy_static = { workspace = true, optional = true }
tokio = { workspace = true }
tokio-util = { workspace = true }

[features]
enable = [
Expand All @@ -57,8 +61,10 @@ enable = [
"serde_with",
"serde_json",
"hex",
"reqwest",
"lazy_static"
]
cuda = ["risc0-zkvm?/cuda"]
metal = ["risc0-zkvm?/metal"]
bench = []

bonsai-auto-scaling = []
9 changes: 9 additions & 0 deletions provers/risc0/driver/src/bonsai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ use std::{

use crate::Risc0Param;

#[cfg(feature = "bonsai-auto-scaling")]
pub mod auto_scaling;

pub async fn verify_bonsai_receipt<O: Eq + Debug + DeserializeOwned>(
image_id: Digest,
expected_output: &O,
Expand Down Expand Up @@ -116,6 +119,10 @@ pub async fn maybe_prove<I: Serialize, O: Eq + Debug + Serialize + DeserializeOw
info!("Loaded locally cached stark receipt {receipt_label:?}");
(cached_data.0, cached_data.1, true)
} else if param.bonsai {
#[cfg(feature = "bonsai-auto-scaling")]
auto_scaling::maxpower_bonsai()
.await
.expect("Failed to set max power on Bonsai");
// query bonsai service until it works
loop {
match prove_bonsai(
Expand Down Expand Up @@ -194,6 +201,8 @@ pub async fn cancel_proof(uuid: String) -> anyhow::Result<()> {
let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let session = bonsai_sdk::alpha::SessionId { uuid };
session.stop(&client)?;
#[cfg(feature = "bonsai-auto-scaling")]
auto_scaling::shutdown_bonsai().await?;
Ok(())
}

Expand Down
227 changes: 227 additions & 0 deletions provers/risc0/driver/src/bonsai/auto_scaling.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
#![cfg(feature = "bonsai-auto-scaling")]

use anyhow::{Error, Ok, Result};
use lazy_static::lazy_static;
use log::info;
use once_cell::sync::Lazy;
use reqwest::{header::HeaderMap, header::HeaderValue, header::CONTENT_TYPE, Client};
use serde::Deserialize;
use std::env;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, error as trace_err};

/// Worker counts reported by the Bonsai `/workers` scaler endpoint.
///
/// The get-test below asserts `desired == current + pending`, i.e. the target
/// count equals running plus still-provisioning workers.
#[derive(Debug, Deserialize, Default)]
struct ScalerResponse {
    /// Worker count the scaler is converging towards.
    desired: u32,
    /// Workers currently running.
    current: u32,
    /// Workers requested but not yet running.
    pending: u32,
}
/// Thin client for the Bonsai auto-scaling REST endpoint.
struct BonsaiAutoScaler {
    /// Full `/workers` endpoint URL (base API URL + "/workers").
    url: String,
    /// Headers sent with every request: JSON content type and `x-api-key`.
    headers: HeaderMap,
    /// Reused reqwest client.
    client: Client,
    /// Last requested-but-unconfirmed scaling change, if any; cleared once
    /// `wait_for_bonsai_config_active` observes the requested count live.
    on_setting_status: Option<ScalerResponse>,
}

impl BonsaiAutoScaler {
fn new(bonsai_api_url: String, api_key: String) -> Self {
let url = bonsai_api_url + "/workers";
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
headers.insert("x-api-key", HeaderValue::from_str(&api_key).unwrap());

Self {
url,
headers,
client: Client::new(),
on_setting_status: None,
}
}

async fn get_bonsai_gpu_num(&self) -> Result<ScalerResponse> {
debug!("Requesting scaler status from: {}", self.url);
let response = self
.client
.get(self.url.clone())
.headers(self.headers.clone())
.send()
.await?;

// Check if the request was successful
if response.status().is_success() {
// Parse the JSON response
let data: ScalerResponse = response.json().await.unwrap_or_default();
debug!("Scaler status: {data:?}");
Ok(data)
} else {
trace_err!("Request failed with status: {}", response.status());
Err(Error::msg("Failed to get bonsai gpu num".to_string()))
}
}

async fn set_bonsai_gpu_num(&mut self, gpu_num: u32) -> Result<()> {
if self.on_setting_status.is_some() {
// log an err if there is a race adjustment.
trace_err!("Last bonsai setting is not active, please check.");
}

debug!("Requesting scaler status from: {}", self.url);
let response = self
.client
.post(self.url.clone())
.headers(self.headers.clone())
.body(gpu_num.to_string())
.send()
.await?;

// Check if the request was successful
if response.status().is_success() {
self.on_setting_status = Some(ScalerResponse {
desired: gpu_num,
current: 0,
pending: 0,
});
Ok(())
} else {
trace_err!("Request failed with status: {}", response.status());
Err(Error::msg("Failed to get bonsai gpu num".to_string()))
}
}

async fn wait_for_bonsai_config_active(&mut self, time_out_sec: u64) -> Result<()> {
match &self.on_setting_status {
None => Ok(()),
Some(setting) => {
// loop until some timeout
let start_time = std::time::Instant::now();
let mut check_time = std::time::Instant::now();
while check_time.duration_since(start_time).as_secs() < time_out_sec {
tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
check_time = std::time::Instant::now();
let current_bonsai_gpu_num = self.get_bonsai_gpu_num().await?;
if current_bonsai_gpu_num.current == setting.desired {
self.on_setting_status = None;
return Ok(());
}
}
Err(Error::msg(
"checking bonsai config active timeout".to_string(),
))
}
}
}
}

lazy_static! {
    /// Base URL of the Bonsai REST API; required, read at first use.
    static ref BONSAI_API_URL: String =
        env::var("BONSAI_API_URL").expect("BONSAI_API_URL must be set");
    /// API key for the Bonsai REST API; required, read at first use.
    static ref BONSAI_API_KEY: String =
        env::var("BONSAI_API_KEY").expect("BONSAI_API_KEY must be set");
    /// Upper bound on provisioned GPUs; defaults to 15 when unset.
    static ref MAX_BONSAI_GPU_NUM: u32 = env::var("MAX_BONSAI_GPU_NUM")
        .unwrap_or_else(|_| "15".to_string())
        .parse()
        .expect("MAX_BONSAI_GPU_NUM must be a valid u32");
}

/// Process-wide scaler instance, lazily built from the environment on first
/// use. A `tokio::sync::Mutex` serializes concurrent scale requests. The
/// static is already `'static` and never cloned, so no `Arc` wrapper is
/// needed; `.lock().await` works identically through `Deref`.
static AUTO_SCALER: Lazy<Mutex<BonsaiAutoScaler>> = Lazy::new(|| {
    Mutex::new(BonsaiAutoScaler::new(
        BONSAI_API_URL.to_string(),
        BONSAI_API_KEY.to_string(),
    ))
});

/// Number of in-flight provers that requested max power; the pool is only
/// scaled down when this drops back to zero.
static REF_COUNT: Lazy<Mutex<u32>> = Lazy::new(|| Mutex::new(0));

/// Scale the Bonsai GPU pool to its configured maximum and register the
/// caller as an active user of the pool.
///
/// The reference count is incremented only after scaling has succeeded (or
/// was already in effect). Previously it was incremented up front, so a
/// failed scale-up leaked a reference and `shutdown_bonsai` could never
/// reach zero, stranding the pool at max power.
///
/// # Errors
/// Propagates scaler API failures and the activation timeout (900 s).
pub(crate) async fn maxpower_bonsai() -> Result<()> {
    // Hold both locks for the whole operation so the count and the scaler
    // state stay consistent under concurrent callers (same acquisition
    // order as `shutdown_bonsai` to avoid deadlock).
    let mut ref_count = REF_COUNT.lock().await;
    let mut auto_scaler = AUTO_SCALER.lock().await;
    let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await?;
    // Scale up unless already maxed out with nothing pending.
    let already_maxed = current_gpu_num.current == *MAX_BONSAI_GPU_NUM
        && current_gpu_num.desired == *MAX_BONSAI_GPU_NUM
        && current_gpu_num.pending == 0;
    if !already_maxed {
        info!("setting bonsai gpu num to: {:?}", *MAX_BONSAI_GPU_NUM);
        auto_scaler.set_bonsai_gpu_num(*MAX_BONSAI_GPU_NUM).await?;
        auto_scaler.wait_for_bonsai_config_active(900).await?;
    }
    *ref_count += 1;
    Ok(())
}

/// Deregister one active user of the Bonsai GPU pool and, when no users
/// remain, scale the pool down to zero workers.
///
/// # Errors
/// Propagates scaler API failures and the scale-down timeout (90 s).
pub(crate) async fn shutdown_bonsai() -> Result<()> {
    let mut active_provers = REF_COUNT.lock().await;
    // Saturating: an extra shutdown call must not underflow the counter.
    *active_provers = active_provers.saturating_sub(1);

    // Other provers are still running: keep the pool as-is.
    if *active_provers != 0 {
        return Ok(());
    }

    let mut scaler = AUTO_SCALER.lock().await;
    let status = scaler.get_bonsai_gpu_num().await?;
    let already_idle = status.current == 0 && status.desired == 0 && status.pending == 0;
    if already_idle {
        return Ok(());
    }

    info!("setting bonsai gpu num to: 0");
    scaler.set_bonsai_gpu_num(0).await?;
    scaler.wait_for_bonsai_config_active(90).await
}

#[cfg(test)]
mod test {
    use super::*;
    use std::env;
    use tokio;

    // Integration test against a live Bonsai deployment; needs
    // BONSAI_API_URL / BONSAI_API_KEY set, hence #[ignore] by default.
    #[ignore]
    #[tokio::test]
    async fn test_bonsai_auto_scaler_get() {
        let bonsai_url = env::var("BONSAI_API_URL").expect("BONSAI_API_URL must be set");
        let bonsai_key = env::var("BONSAI_API_KEY").expect("BONSAI_API_KEY must be set");
        let max_bonsai_gpu: u32 = env::var("MAX_BONSAI_GPU_NUM")
            .unwrap_or_else(|_| "15".to_string())
            .parse()
            .unwrap();
        let auto_scaler = BonsaiAutoScaler::new(bonsai_url, bonsai_key);
        let scalar_status = auto_scaler.get_bonsai_gpu_num().await.unwrap();
        assert!(scalar_status.current <= max_bonsai_gpu);
        // Scaler invariant: target count = running + provisioning.
        assert_eq!(
            scalar_status.desired,
            scalar_status.current + scalar_status.pending
        );
    }

    // Round-trips a scale-up to 7 workers and back to 0 against a live
    // deployment; mutates real infrastructure, hence #[ignore] by default.
    #[ignore]
    #[tokio::test]
    async fn test_bonsai_auto_scaler_set() {
        let bonsai_url = env::var("BONSAI_API_URL").expect("BONSAI_API_URL must be set");
        let bonsai_key = env::var("BONSAI_API_KEY").expect("BONSAI_API_KEY must be set");
        let mut auto_scaler = BonsaiAutoScaler::new(bonsai_url, bonsai_key);

        auto_scaler
            .set_bonsai_gpu_num(7)
            .await
            .expect("Failed to set bonsai gpu num");
        // Generous timeout: provisioning 7 workers can take minutes.
        auto_scaler
            .wait_for_bonsai_config_active(600)
            .await
            .unwrap();
        let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await.unwrap().current;
        assert_eq!(current_gpu_num, 7);

        auto_scaler
            .set_bonsai_gpu_num(0)
            .await
            .expect("Failed to set bonsai gpu num");
        // Scale-down is quicker, so a shorter timeout suffices.
        auto_scaler.wait_for_bonsai_config_active(60).await.unwrap();
        let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await.unwrap().current;
        assert_eq!(current_gpu_num, 0);
    }
}
26 changes: 17 additions & 9 deletions provers/risc0/driver/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
#![cfg(feature = "enable")]

#[cfg(feature = "bonsai-auto-scaling")]
use crate::bonsai::auto_scaling::shutdown_bonsai;
use crate::{
methods::risc0_guest::{RISC0_GUEST_ELF, RISC0_GUEST_ID},
snarks::verify_groth16_snark,
};
use alloy_primitives::B256;
use hex::ToHex;
use log::warn;
Expand All @@ -13,11 +19,6 @@ use serde_with::serde_as;
use std::fmt::Debug;
use tracing::{debug, info as traicing_info};

use crate::{
methods::risc0_guest::{RISC0_GUEST_ELF, RISC0_GUEST_ID},
snarks::verify_groth16_snark,
};

pub use bonsai::*;

pub mod bonsai;
Expand Down Expand Up @@ -83,8 +84,8 @@ impl Prover for Risc0Prover {

let journal: String = result.clone().unwrap().1.journal.encode_hex();

// Create/verify Groth16 SNARK
let snark_proof = if config.snark {
// Create/verify Groth16 SNARK in bonsai
let snark_proof = if config.snark && config.bonsai {
let Some((stark_uuid, stark_receipt)) = result else {
return Err(ProverError::GuestError(
"No STARK data to snarkify!".to_owned(),
Expand All @@ -108,6 +109,14 @@ impl Prover for Risc0Prover {
journal
};

#[cfg(feature = "bonsai-auto-scaling")]
if config.bonsai {
// shutdown bonsai
shutdown_bonsai()
.await
.map_err(|e| ProverError::GuestError(e.to_string()))?;
}

Ok(Risc0Response { proof: snark_proof }.into())
}

Expand All @@ -125,8 +134,7 @@ impl Prover for Risc0Prover {
cancel_proof(uuid)
.await
.map_err(|e| ProverError::GuestError(e.to_string()))?;
id_store.remove_id(key).await?;
Ok(())
id_store.remove_id(key).await
}
}

Expand Down
2 changes: 1 addition & 1 deletion provers/risc0/driver/src/methods/ecdsa.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pub const ECDSA_ELF: &[u8] =
include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/ecdsa");
pub const ECDSA_ID: [u32; 8] = [
3314277365, 903638368, 2823387338, 975292771, 2962241176, 3386670094, 1262198564, 423457744,
1166688769, 1407190737, 3347938864, 1261472884, 3997842354, 3752365982, 4108615966, 2506107654,
];
Loading

0 comments on commit dd8b440

Please sign in to comment.