diff --git a/Cargo.lock b/Cargo.lock index dbf1298f0..8c17c9776 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7296,13 +7296,17 @@ dependencies = [ "ethers-core 2.0.10", "ethers-providers 2.0.10", "hex", + "lazy_static", "log", "once_cell", "raiko-lib", + "reqwest 0.11.27", "risc0-zkvm", "serde", "serde_json", "serde_with 3.9.0", + "tokio", + "tokio-util", "tracing", "typetag", ] diff --git a/Cargo.toml b/Cargo.toml index 7159419db..f671f3b3b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,9 +57,9 @@ risc0-build = { version = "1.0.1" } risc0-binfmt = { version = "1.0.1" } # SP1 -sp1-sdk = { version = "1.0.1" } -sp1-zkvm = { version = "1.0.1" } -sp1-helper = { version = "1.0.1" } +sp1-sdk = { version = "1.0.1" } +sp1-zkvm = { version = "1.0.1" } +sp1-helper = { version = "1.0.1" } # alloy alloy-rlp = { version = "0.3.4", default-features = false } @@ -189,4 +189,4 @@ revm-primitives = { git = "https://github.com/taikoxyz/revm.git", branch = "v36- revm-precompile = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-taiko" } secp256k1 = { git = "https://github.com/CeciliaZ030/rust-secp256k1", branch = "sp1-patch" } blst = { git = "https://github.com/CeciliaZ030/blst.git", branch = "v0.3.12-serialize" } -alloy-serde = { git = "https://github.com/CeciliaZ030/alloy.git", branch = "v0.1.4-fix"} \ No newline at end of file +alloy-serde = { git = "https://github.com/CeciliaZ030/alloy.git", branch = "v0.1.4-fix" } diff --git a/Dockerfile b/Dockerfile index 653571ec4..a5ba91e02 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,7 @@ ARG BUILD_FLAGS="" WORKDIR /opt/raiko COPY . . -RUN cargo build --release ${BUILD_FLAGS} --features "sgx" --features "docker_build" +RUN cargo build --release ${BUILD_FLAGS} --features "sgx,risc0" --features "docker_build" FROM gramineproject/gramine:1.6-jammy as runtime ENV DEBIAN_FRONTEND=noninteractive diff --git a/provers/risc0/driver/Cargo.toml b/provers/risc0/driver/Cargo.toml index d9a542262..91abe106c 100644 --- a/provers/risc0/driver/Cargo.toml +++ b/provers/risc0/driver/Cargo.toml @@ -34,6 +34,10 @@ typetag = { workspace = true, optional = true } serde_with = { workspace = true, optional = true } serde_json = { workspace = true, optional = true } hex = { workspace = true, optional = true } +reqwest = { workspace = true, optional = true } +lazy_static = { workspace = true, optional = true } +tokio = { workspace = true } +tokio-util = { workspace = true } [features] enable = [ @@ -57,6 +61,8 @@ enable = [ "serde_with", "serde_json", "hex", + "reqwest", + "lazy_static" ] cuda = ["risc0-zkvm?/cuda"] metal = ["risc0-zkvm?/metal"] diff --git a/provers/risc0/driver/src/bonsai.rs b/provers/risc0/driver/src/bonsai.rs index e0dd9173c..5f57bfc92 100644 --- a/provers/risc0/driver/src/bonsai.rs +++ b/provers/risc0/driver/src/bonsai.rs @@ -16,6 +16,8 @@ use std::{ use crate::Risc0Param; +pub mod auto_scaling; + pub async fn verify_bonsai_receipt( image_id: Digest, expected_output: &O, @@ -194,6 +196,7 @@ pub async fn cancel_proof(uuid: String) -> anyhow::Result<()> { let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?; let session = bonsai_sdk::alpha::SessionId { uuid }; session.stop(&client)?; + auto_scaling::shutdown_bonsai().await?; Ok(()) } diff --git a/provers/risc0/driver/src/bonsai/auto_scaling.rs b/provers/risc0/driver/src/bonsai/auto_scaling.rs new file mode 100644 index 000000000..e67808813 --- /dev/null +++ b/provers/risc0/driver/src/bonsai/auto_scaling.rs @@ -0,0 +1,204 @@ +use anyhow::{Error, Ok, Result}; +use lazy_static::lazy_static; +use reqwest::{header::HeaderMap, header::HeaderValue, header::CONTENT_TYPE, Client}; +use serde::Deserialize; +use std::env; +use tracing::{debug, error as trace_err}; + +#[derive(Debug, Deserialize, Default)] +struct ScalerResponse { + desired: u32, + current: u32, + pending: u32, +} +struct BonsaiAutoScaler { + url: String, + headers: HeaderMap, + client: Client, + on_setting_status: Option, +} + +impl BonsaiAutoScaler { + fn new(bonsai_api_url: String, api_key: String) -> Self { + let url = bonsai_api_url + "/workers"; + let mut headers = HeaderMap::new(); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + headers.insert("x-api-key", HeaderValue::from_str(&api_key).unwrap()); + + Self { + url, + headers, + client: Client::new(), + on_setting_status: None, + } + } + + async fn get_bonsai_gpu_num(&self) -> Result { + debug!("Requesting scaler status from: {}", self.url); + let response = self + .client + .get(self.url.clone()) + .headers(self.headers.clone()) + .send() + .await?; + + // Check if the request was successful + if response.status().is_success() { + // Parse the JSON response + let data: ScalerResponse = response.json().await.unwrap_or_default(); + debug!("Scaler status: {data:?}"); + Ok(data) + } else { + trace_err!("Request failed with status: {}", response.status()); + Err(Error::msg("Failed to get bonsai gpu num".to_string())) + } + } + + async fn set_bonsai_gpu_num(&mut self, gpu_num: u32) -> Result<()> { + if self.on_setting_status.is_some() { + // log an err if there is a race adjustment. + trace_err!("Last bonsai setting is not active, please check."); + } + + debug!("Requesting scaler status from: {}", self.url); + let response = self + .client + .post(self.url.clone()) + .headers(self.headers.clone()) + .body(gpu_num.to_string()) + .send() + .await?; + + // Check if the request was successful + if response.status().is_success() { + self.on_setting_status = Some(ScalerResponse { + desired: gpu_num, + current: 0, + pending: 0, + }); + Ok(()) + } else { + trace_err!("Request failed with status: {}", response.status()); + Err(Error::msg("Failed to get bonsai gpu num".to_string())) + } + } + + async fn wait_for_bonsai_config_active(&mut self, time_out_sec: u64) -> Result<()> { + match &self.on_setting_status { + None => Ok(()), + Some(setting) => { + // loop until some timeout + let start_time = std::time::Instant::now(); + let mut check_time = std::time::Instant::now(); + while check_time.duration_since(start_time).as_secs() < time_out_sec { + tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; + check_time = std::time::Instant::now(); + let current_bonsai_gpu_num = self.get_bonsai_gpu_num().await?; + if current_bonsai_gpu_num.current == setting.desired { + self.on_setting_status = None; + return Ok(()); + } + } + Err(Error::msg( + "checking bonsai config active timeout".to_string(), + )) + } + } + } +} + +lazy_static! { + static ref BONSAI_API_URL: String = + env::var("BONSAI_API_URL").expect("BONSAI_API_URL must be set"); + static ref BONSAI_API_KEY: String = + env::var("BONSAI_API_KEY").expect("BONSAI_API_KEY must be set"); + static ref MAX_BONSAI_GPU_NUM: u32 = env::var("MAX_BONSAI_GPU_NUM") + .unwrap_or_else(|_| "15".to_string()) + .parse() + .unwrap(); +} + +pub(crate) async fn maxpower_bonsai() -> Result<()> { + let mut auto_scaler = + BonsaiAutoScaler::new(BONSAI_API_URL.to_string(), BONSAI_API_KEY.to_string()); + let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await?; + // either already maxed out or pending to be maxed out + if current_gpu_num.current == *MAX_BONSAI_GPU_NUM + && current_gpu_num.desired == *MAX_BONSAI_GPU_NUM + && current_gpu_num.pending == 0 + { + Ok(()) + } else { + auto_scaler.set_bonsai_gpu_num(*MAX_BONSAI_GPU_NUM).await?; + auto_scaler.wait_for_bonsai_config_active(300).await + } +} + +pub(crate) async fn shutdown_bonsai() -> Result<()> { + let mut auto_scaler = + BonsaiAutoScaler::new(BONSAI_API_URL.to_string(), BONSAI_API_KEY.to_string()); + let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await?; + if current_gpu_num.current == 0 && current_gpu_num.pending == 0 && current_gpu_num.desired == 0 + { + Ok(()) + } else { + auto_scaler.set_bonsai_gpu_num(0).await?; + // wait few minute for the bonsai to cool down + auto_scaler.wait_for_bonsai_config_active(30).await + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::env; + use tokio; + + #[ignore] + #[tokio::test] + async fn test_bonsai_auto_scaler_get() { + let bonsai_url = env::var("BONSAI_API_URL").expect("BONSAI_API_URL must be set"); + let bonsai_key = env::var("BONSAI_API_KEY").expect("BONSAI_API_KEY must be set"); + let max_bonsai_gpu: u32 = env::var("MAX_BONSAI_GPU_NUM") + .unwrap_or_else(|_| "15".to_string()) + .parse() + .unwrap(); + let auto_scaler = BonsaiAutoScaler::new(bonsai_url, bonsai_key); + let scalar_status = auto_scaler.get_bonsai_gpu_num().await.unwrap(); + assert!(scalar_status.current <= max_bonsai_gpu); + assert_eq!( + scalar_status.desired, + scalar_status.current + scalar_status.pending + ); + } + + #[ignore] + #[tokio::test] + async fn test_bonsai_auto_scaler_set() { + let bonsai_url = env::var("BONSAI_API_URL").expect("BONSAI_API_URL must be set"); + let bonsai_key = env::var("BONSAI_API_KEY").expect("BONSAI_API_KEY must be set"); + let mut auto_scaler = BonsaiAutoScaler::new(bonsai_url, bonsai_key); + + auto_scaler + .set_bonsai_gpu_num(7) + .await + .expect("Failed to set bonsai gpu num"); + auto_scaler + .wait_for_bonsai_config_active(300) + .await + .unwrap(); + let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await.unwrap().current; + assert_eq!(current_gpu_num, 7); + + auto_scaler + .set_bonsai_gpu_num(0) + .await + .expect("Failed to set bonsai gpu num"); + auto_scaler + .wait_for_bonsai_config_active(300) + .await + .unwrap(); + let current_gpu_num = auto_scaler.get_bonsai_gpu_num().await.unwrap().current; + assert_eq!(current_gpu_num, 0); + } +} diff --git a/provers/risc0/driver/src/lib.rs b/provers/risc0/driver/src/lib.rs index 8a0e36f6c..6e9920614 100644 --- a/provers/risc0/driver/src/lib.rs +++ b/provers/risc0/driver/src/lib.rs @@ -14,6 +14,7 @@ use std::fmt::Debug; use tracing::{debug, info as traicing_info}; use crate::{ + bonsai::auto_scaling::{maxpower_bonsai, shutdown_bonsai}, methods::risc0_guest::{RISC0_GUEST_ELF, RISC0_GUEST_ID}, snarks::verify_groth16_snark, }; @@ -70,6 +71,13 @@ impl Prover for Risc0Prover { debug!("elf code length: {}", RISC0_GUEST_ELF.len()); let encoded_input = to_vec(&input).expect("Could not serialize proving input!"); + if config.bonsai { + // make max speed bonsai + maxpower_bonsai() + .await + .expect("Failed to set max power on Bonsai"); + } + let result = maybe_prove::( &config, encoded_input, @@ -83,8 +91,8 @@ impl Prover for Risc0Prover { let journal: String = result.clone().unwrap().1.journal.encode_hex(); - // Create/verify Groth16 SNARK - let snark_proof = if config.snark { + // Create/verify Groth16 SNARK in bonsai + let snark_proof = if config.snark && config.bonsai { let Some((stark_uuid, stark_receipt)) = result else { return Err(ProverError::GuestError( "No STARK data to snarkify!".to_owned(), @@ -108,6 +116,13 @@ impl Prover for Risc0Prover { journal }; + if config.bonsai { + // shutdown max speed bonsai + shutdown_bonsai() + .await + .map_err(|e| ProverError::GuestError(e.to_string()))?; + } + Ok(Risc0Response { proof: snark_proof }.into()) } @@ -125,8 +140,7 @@ impl Prover for Risc0Prover { cancel_proof(uuid) .await .map_err(|e| ProverError::GuestError(e.to_string()))?; - id_store.remove_id(key).await?; - Ok(()) + id_store.remove_id(key).await } } diff --git a/provers/risc0/driver/src/methods/ecdsa.rs b/provers/risc0/driver/src/methods/ecdsa.rs index e7af708bf..7fe04062b 100644 --- a/provers/risc0/driver/src/methods/ecdsa.rs +++ b/provers/risc0/driver/src/methods/ecdsa.rs @@ -1,5 +1,5 @@ pub const ECDSA_ELF: &[u8] = include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/ecdsa"); pub const ECDSA_ID: [u32; 8] = [ - 3314277365, 903638368, 2823387338, 975292771, 2962241176, 3386670094, 1262198564, 423457744, + 1166688769, 1407190737, 3347938864, 1261472884, 3997842354, 3752365982, 4108615966, 2506107654, ]; diff --git a/provers/risc0/driver/src/methods/sha256.rs b/provers/risc0/driver/src/methods/sha256.rs index d2a93b856..4302e7320 100644 --- a/provers/risc0/driver/src/methods/sha256.rs +++ b/provers/risc0/driver/src/methods/sha256.rs @@ -1,5 +1,5 @@ pub const SHA256_ELF: &[u8] = include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/sha256"); pub const SHA256_ID: [u32; 8] = [ - 3506084161, 1146489446, 485833862, 3404354046, 3626029993, 1928006034, 3833244069, 3073098029, + 1030743442, 3697463329, 2083175350, 1726292372, 629109085, 444583534, 849554126, 3148184953, ]; diff --git a/provers/risc0/driver/src/methods/test_risc0_guest.rs b/provers/risc0/driver/src/methods/test_risc0_guest.rs index 8305a5555..7db0ebf0e 100644 --- a/provers/risc0/driver/src/methods/test_risc0_guest.rs +++ b/provers/risc0/driver/src/methods/test_risc0_guest.rs @@ -1,6 +1,6 @@ pub const TEST_RISC0_GUEST_ELF: &[u8] = include_bytes!( - "../../../guest/target/riscv32im-risc0-zkvm-elf/release/deps/risc0_guest-4b4f18d42a260659" + "../../../guest/target/riscv32im-risc0-zkvm-elf/release/deps/risc0_guest-3bef88267f07d7e2" ); pub const TEST_RISC0_GUEST_ID: [u32; 8] = [ - 3216516244, 2583889163, 799150854, 107525368, 1015178806, 1451965571, 3377528142, 1073775, + 947177299, 3433149683, 3077752115, 1716500464, 3011459317, 622725533, 247263939, 1661915565, ]; diff --git a/provers/risc0/driver/src/snarks.rs b/provers/risc0/driver/src/snarks.rs index e10af3e32..5cc00d232 100644 --- a/provers/risc0/driver/src/snarks.rs +++ b/provers/risc0/driver/src/snarks.rs @@ -70,11 +70,6 @@ abigen!( ]"# ); -// /// ABI encoding of the seal. -// pub fn abi_encode(seal: Vec) -> Result> { -// Ok(encode(seal)?.abi_encode()) -// } - /// encoding of the seal with selector. pub fn encode(seal: Vec) -> Result> { let verifier_parameters_digest = Groth16ReceiptVerifierParameters::default().digest();