feat: impl reqpool (#447)
* feat: impl reqpool

* feat(reqpool): mock pool

* feat(reqpool): remove Pool trait

* feat(reqpool): remove memory pool

* test(reqpool): add case for multiple redis pools

* chore: allow unused_imports

* feat(reqpool): update Cargo.toml

* chore(reqpool): impl Display for Status

* feat(reqpool): remove feature "enable-mock"
keroro520 authored Jan 22, 2025

1 parent ac752ee commit 9a243c0
Showing 8 changed files with 804 additions and 0 deletions.
28 changes: 28 additions & 0 deletions reqpool/Cargo.toml
@@ -0,0 +1,28 @@
[package]
name = "raiko-reqpool"
version = "0.1.0"
authors = ["Taiko Labs"]
edition = "2021"

[dependencies]
raiko-lib = { workspace = true }
raiko-core = { workspace = true }
raiko-redis-derive = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
serde = { workspace = true }
serde_json = { workspace = true }
serde_with = { workspace = true }
tracing = { workspace = true }
tokio = { workspace = true }
async-trait = { workspace = true }
redis = { workspace = true }
backoff = { workspace = true }
derive-getters = { workspace = true }
proc-macro2 = { workspace = true }
quote = { workspace = true }
syn = { workspace = true }
alloy-primitives = { workspace = true }
lazy_static = { workspace = true }

[features]
test-utils = []
10 changes: 10 additions & 0 deletions reqpool/src/config.rs
@@ -0,0 +1,10 @@
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
/// The configuration for the Redis-backed request pool
pub struct RedisPoolConfig {
/// The URL of the Redis database, e.g. "redis://localhost:6379"
pub redis_url: String,
/// The TTL, in seconds, applied to entries written to Redis
pub redis_ttl: u64,
}
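
For reference, a minimal sketch of constructing this config by hand and round-tripping it through serde_json. The URL and TTL are placeholder values; the TTL is treated as seconds because the pool forwards it directly to Redis `SET ... EX` via `set_ex` (see redis_pool.rs below).

```rust
use raiko_reqpool::RedisPoolConfig;

fn main() -> Result<(), serde_json::Error> {
    // Placeholder values; adjust to your deployment.
    let config = RedisPoolConfig {
        redis_url: "redis://localhost:6379".to_string(),
        redis_ttl: 24 * 60 * 60, // entries expire after one day (seconds)
    };

    // The struct derives Serialize/Deserialize, so it can be loaded from a JSON config.
    let json = serde_json::to_string_pretty(&config)?;
    let parsed: RedisPoolConfig = serde_json::from_str(&json)?;
    assert_eq!(parsed.redis_url, config.redis_url);
    assert_eq!(parsed.redis_ttl, config.redis_ttl);
    Ok(())
}
```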
16 changes: 16 additions & 0 deletions reqpool/src/lib.rs
@@ -0,0 +1,16 @@
mod config;
mod macros;
mod mock;
mod redis_pool;
mod request;
mod utils;

// Re-export
pub use config::RedisPoolConfig;
pub use mock::{mock_redis_pool, MockRedisConnection};
pub use redis_pool::Pool;
pub use request::{
AggregationRequestEntity, AggregationRequestKey, RequestEntity, RequestKey,
SingleProofRequestEntity, SingleProofRequestKey, Status, StatusWithContext,
};
pub use utils::proof_key_to_hack_request_key;
44 changes: 44 additions & 0 deletions reqpool/src/macros.rs
@@ -0,0 +1,44 @@
/// This macro implements the Display trait for a type by using serde_json's pretty printing.
/// If the type cannot be serialized to JSON, it falls back to using Debug formatting.
///
/// # Example
///
/// ```rust
/// use raiko_reqpool::impl_display_using_json_pretty;
/// use serde::{Serialize, Deserialize};
///
/// #[derive(Debug, Serialize, Deserialize)]
/// struct Person {
/// name: String,
/// age: u32
/// }
///
/// impl_display_using_json_pretty!(Person);
///
/// let person = Person {
/// name: "John".to_string(),
/// age: 30
/// };
///
/// // Will print:
/// // {
/// // "name": "John",
/// // "age": 30
/// // }
/// println!("{}", person);
/// ```
///
/// The type must implement serde's Serialize trait for JSON serialization to work.
/// If serialization fails, it will fall back to using the Debug implementation.
#[macro_export]
macro_rules! impl_display_using_json_pretty {
($type:ty) => {
impl std::fmt::Display for $type {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match serde_json::to_string_pretty(self) {
Ok(s) => write!(f, "{}", s),
Err(_) => write!(f, "{:?}", self),
}
}
}
};
}
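
A short sketch of the fallback path described in the doc comment above: serde_json refuses to serialize maps with non-string keys, so a type holding such a map takes the Debug branch, while an ordinary struct pretty-prints as JSON. The example assumes serde and serde_json are dependencies at the call site, since the expanded macro refers to `serde_json::` directly.

```rust
use raiko_reqpool::impl_display_using_json_pretty;
use serde::Serialize;
use std::collections::HashMap;

#[derive(Debug, Serialize)]
struct Simple {
    name: String,
}

// serde_json rejects non-string map keys at serialization time,
// so Display for this type falls back to the Debug representation.
#[derive(Debug, Serialize)]
struct EdgeCase {
    scores: HashMap<(u8, u8), String>,
}

impl_display_using_json_pretty!(Simple);
impl_display_using_json_pretty!(EdgeCase);

fn main() {
    // Pretty-JSON branch.
    println!("{}", Simple { name: "reqpool".to_string() });

    // Fallback branch: serialization error => Debug output.
    let mut scores = HashMap::new();
    scores.insert((1, 2), "three".to_string());
    println!("{}", EdgeCase { scores });
}
```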
153 changes: 153 additions & 0 deletions reqpool/src/mock.rs
@@ -0,0 +1,153 @@
use crate::{Pool, RedisPoolConfig};
use lazy_static::lazy_static;
use redis::{RedisError, RedisResult};
use serde::Serialize;
use serde_json::{json, Value};
use std::{
collections::HashMap,
sync::{Arc, Mutex},
};

type SingleStorage = Arc<Mutex<HashMap<Value, Value>>>;
type GlobalStorage = Mutex<HashMap<String, SingleStorage>>;

lazy_static! {
// #{redis_url => single_storage}
//
// We use the redis_url to distinguish different redis databases in tests, preventing
// data races when multiple tests run concurrently.
static ref GLOBAL_STORAGE: GlobalStorage = Mutex::new(HashMap::new());
}

pub struct MockRedisConnection {
storage: SingleStorage,
}

impl MockRedisConnection {
pub fn new(redis_url: String) -> Self {
let mut global = GLOBAL_STORAGE.lock().unwrap();
Self {
storage: global
.entry(redis_url)
.or_insert_with(|| Arc::new(Mutex::new(HashMap::new())))
.clone(),
}
}

pub fn set_ex<K: Serialize, V: Serialize>(
&mut self,
key: K,
val: V,
_ttl: u64,
) -> RedisResult<()> {
let mut lock = self.storage.lock().unwrap();
lock.insert(json!(key), json!(val));
Ok(())
}

pub fn get<K: Serialize, V: serde::de::DeserializeOwned>(&mut self, key: &K) -> RedisResult<V> {
let lock = self.storage.lock().unwrap();
match lock.get(&json!(key)) {
None => Err(RedisError::from((redis::ErrorKind::TypeError, "not found"))),
Some(v) => serde_json::from_value(v.clone()).map_err(|e| {
RedisError::from((
redis::ErrorKind::TypeError,
"deserialization error",
e.to_string(),
))
}),
}
}

pub fn del<K: Serialize>(&mut self, key: K) -> RedisResult<usize> {
let mut lock = self.storage.lock().unwrap();
if lock.remove(&json!(key)).is_none() {
Ok(0)
} else {
Ok(1)
}
}

pub fn keys<K: serde::de::DeserializeOwned>(&mut self, key: &str) -> RedisResult<Vec<K>> {
assert_eq!(key, "*", "mock redis only supports '*'");

let lock = self.storage.lock().unwrap();
Ok(lock
.keys()
.map(|k| serde_json::from_value(k.clone()).unwrap())
.collect())
}
}

/// Returns a mock redis pool associated with the given id.
///
/// This is used for testing. Use the test case name as the id to prevent data races between tests.
pub fn mock_redis_pool<S: ToString>(id: S) -> Pool {
let config = RedisPoolConfig {
redis_ttl: 111,
redis_url: format!("redis://{}:6379", id.to_string()),
};
Pool::open(config).unwrap()
}

#[cfg(test)]
mod tests {
use super::*;
use redis::RedisResult;

#[test]
fn test_mock_redis_pool() {
let mut pool = mock_redis_pool("test_mock_redis_pool");
let mut conn = pool.conn().expect("mock conn");

let key = "hello".to_string();
let val = "world".to_string();
conn.set_ex(key.clone(), val.clone(), 111)
.expect("mock set_ex");

let actual: RedisResult<String> = conn.get(&key);
assert_eq!(actual, Ok(val));

let _ = conn.del(&key);
let actual: RedisResult<String> = conn.get(&key);
assert!(actual.is_err());
}

#[test]
fn test_mock_multiple_redis_pool() {
let mut pool1 = mock_redis_pool("test_mock_multiple_redis_pool_1");
let mut pool2 = mock_redis_pool("test_mock_multiple_redis_pool_2");
let mut conn1 = pool1.conn().expect("mock conn");
let mut conn2 = pool2.conn().expect("mock conn");

let key = "hello".to_string();
let world = "world".to_string();

{
conn1
.set_ex(key.clone(), world.clone(), 111)
.expect("mock set_ex");
let actual: RedisResult<String> = conn1.get(&key);
assert_eq!(actual, Ok(world.clone()));
}

{
let actual: RedisResult<String> = conn2.get(&key);
assert!(actual.is_err());
}

{
let meme = "meme".to_string();
conn2
.set_ex(key.clone(), meme.clone(), 111)
.expect("mock set_ex");
let actual: RedisResult<String> = conn2.get(&key);
assert_eq!(actual, Ok(meme));
}

{
let actual: RedisResult<String> = conn1.get(&key);
assert_eq!(actual, Ok(world));
}
}
}
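
To make the storage-sharing rule concrete, here is a minimal sketch using only `MockRedisConnection`: two connections opened with the same URL see the same in-memory map (via the lazy_static global), while a different URL is isolated. The URLs are arbitrary test ids, not real endpoints.

```rust
use raiko_reqpool::MockRedisConnection;
use redis::RedisResult;

fn main() {
    // Same URL => same underlying map.
    let mut writer = MockRedisConnection::new("redis://shared-demo:6379".to_string());
    let mut reader = MockRedisConnection::new("redis://shared-demo:6379".to_string());
    writer
        .set_ex("key".to_string(), "value".to_string(), 111)
        .expect("mock set_ex");
    let got: RedisResult<String> = reader.get(&"key".to_string());
    assert_eq!(got, Ok("value".to_string()));

    // Different URL => isolated storage, so the key is not visible.
    let mut other = MockRedisConnection::new("redis://other-demo:6379".to_string());
    let missing: RedisResult<String> = other.get(&"key".to_string());
    assert!(missing.is_err());
}
```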
222 changes: 222 additions & 0 deletions reqpool/src/redis_pool.rs
@@ -0,0 +1,222 @@
use crate::{
impl_display_using_json_pretty, proof_key_to_hack_request_key, RedisPoolConfig, RequestEntity,
RequestKey, StatusWithContext,
};
use backoff::{exponential::ExponentialBackoff, SystemClock};
use raiko_lib::prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult};
use raiko_redis_derive::RedisValue;
#[allow(unused_imports)]
use redis::{Client, Commands, RedisResult};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, time::Duration};

#[derive(Debug, Clone)]
pub struct Pool {
client: Client,
config: RedisPoolConfig,
}

impl Pool {
pub fn add(
&mut self,
request_key: RequestKey,
request_entity: RequestEntity,
status: StatusWithContext,
) -> Result<(), String> {
tracing::info!("RedisPool.add: {request_key}, {status}");
let request_entity_and_status = RequestEntityAndStatus {
entity: request_entity,
status,
};
self.conn()
.map_err(|e| e.to_string())?
.set_ex(
request_key,
request_entity_and_status,
self.config.redis_ttl,
)
.map_err(|e| e.to_string())?;
Ok(())
}

pub fn remove(&mut self, request_key: &RequestKey) -> Result<usize, String> {
tracing::info!("RedisPool.remove: {request_key}");
let result: usize = self
.conn()
.map_err(|e| e.to_string())?
.del(request_key)
.map_err(|e| e.to_string())?;
Ok(result)
}

pub fn get(
&mut self,
request_key: &RequestKey,
) -> Result<Option<(RequestEntity, StatusWithContext)>, String> {
let result: RedisResult<RequestEntityAndStatus> =
self.conn().map_err(|e| e.to_string())?.get(request_key);
match result {
Ok(value) => Ok(Some(value.into())),
Err(e) if e.kind() == redis::ErrorKind::TypeError => Ok(None),
Err(e) => Err(e.to_string()),
}
}

pub fn get_status(
&mut self,
request_key: &RequestKey,
) -> Result<Option<StatusWithContext>, String> {
self.get(request_key).map(|v| v.map(|v| v.1))
}

pub fn update_status(
&mut self,
request_key: RequestKey,
status: StatusWithContext,
) -> Result<StatusWithContext, String> {
tracing::info!("RedisPool.update_status: {request_key}, {status}");
match self.get(&request_key)? {
Some((entity, old_status)) => {
self.add(request_key, entity, status)?;
Ok(old_status)
}
None => Err("Request not found".to_string()),
}
}

pub fn list(&mut self) -> Result<HashMap<RequestKey, StatusWithContext>, String> {
let mut conn = self.conn().map_err(|e| e.to_string())?;
let keys: Vec<RequestKey> = conn.keys("*").map_err(|e| e.to_string())?;

let mut result = HashMap::new();
for key in keys {
if let Ok(Some((_, status))) = self.get(&key) {
result.insert(key, status);
}
}

Ok(result)
}
}

#[async_trait::async_trait]
impl IdStore for Pool {
async fn read_id(&mut self, proof_key: ProofKey) -> ProverResult<String> {
let hack_request_key = proof_key_to_hack_request_key(proof_key);

tracing::info!("RedisPool.read_id: {hack_request_key}");

let result: RedisResult<String> = self
.conn()
.map_err(|e| e.to_string())?
.get(&hack_request_key);
match result {
Ok(value) => Ok(value.into()),
Err(e) => Err(ProverError::StoreError(e.to_string())),
}
}
}

#[async_trait::async_trait]
impl IdWrite for Pool {
async fn store_id(&mut self, proof_key: ProofKey, id: String) -> ProverResult<()> {
let hack_request_key = proof_key_to_hack_request_key(proof_key);

tracing::info!("RedisPool.store_id: {hack_request_key}, {id}");

self.conn()
.map_err(|e| e.to_string())?
.set_ex(hack_request_key, id, self.config.redis_ttl)
.map_err(|e| ProverError::StoreError(e.to_string()))?;
Ok(())
}

async fn remove_id(&mut self, proof_key: ProofKey) -> ProverResult<()> {
let hack_request_key = proof_key_to_hack_request_key(proof_key);

tracing::info!("RedisPool.remove_id: {hack_request_key}");

self.conn()
.map_err(|e| e.to_string())?
.del(hack_request_key)
.map_err(|e| ProverError::StoreError(e.to_string()))?;
Ok(())
}
}

impl Pool {
pub fn open(config: RedisPoolConfig) -> Result<Self, redis::RedisError> {
tracing::info!("RedisPool.open: connecting to redis: {}", config.redis_url);

let client = Client::open(config.redis_url.clone())?;
Ok(Self { client, config })
}

#[cfg(any(test, feature = "test-utils"))]
pub(crate) fn conn(&mut self) -> Result<crate::mock::MockRedisConnection, redis::RedisError> {
Ok(crate::mock::MockRedisConnection::new(
self.config.redis_url.clone(),
))
}

#[cfg(not(any(test, feature = "test-utils")))]
fn conn(&mut self) -> Result<redis::Connection, redis::RedisError> {
self.redis_conn()
}

#[allow(dead_code)]
fn redis_conn(&mut self) -> Result<redis::Connection, redis::RedisError> {
let backoff: ExponentialBackoff<SystemClock> = ExponentialBackoff {
initial_interval: Duration::from_secs(10),
max_interval: Duration::from_secs(60),
max_elapsed_time: Some(Duration::from_secs(300)),
..Default::default()
};

backoff::retry(backoff, || match self.client.get_connection() {
Ok(conn) => Ok(conn),
Err(e) => {
tracing::error!(
"RedisPool.get_connection: failed to connect to redis: {e:?}, retrying..."
);

self.client = redis::Client::open(self.config.redis_url.clone())?;
Err(backoff::Error::Transient {
err: e,
retry_after: None,
})
}
})
.map_err(|e| match e {
backoff::Error::Transient {
err,
retry_after: _,
}
| backoff::Error::Permanent(err) => err,
})
}
}

/// An internal wrapper for the request entity and status, used for redis serialization
#[derive(PartialEq, Debug, Clone, Deserialize, Serialize, RedisValue)]
struct RequestEntityAndStatus {
entity: RequestEntity,
status: StatusWithContext,
}

impl From<(RequestEntity, StatusWithContext)> for RequestEntityAndStatus {
fn from(value: (RequestEntity, StatusWithContext)) -> Self {
Self {
entity: value.0,
status: value.1,
}
}
}

impl From<RequestEntityAndStatus> for (RequestEntity, StatusWithContext) {
fn from(value: RequestEntityAndStatus) -> Self {
(value.entity, value.status)
}
}

impl_display_using_json_pretty!(RequestEntityAndStatus);
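
To make the pool lifecycle concrete, here is a hedged sketch of the add → update_status → get_status → remove round trip, run against a mock pool (requires the crate's `test-utils` feature so `Pool::conn` routes to the mock connection). `ProofType::Native` and `ProverSpecificOpts::default()` are assumptions about the raiko_lib/raiko_core types, not guaranteed by this diff; swap in whatever variant and constructor your workspace provides. The request types used here are defined in request.rs below.

```rust
use raiko_core::interfaces::ProverSpecificOpts;
use raiko_lib::proof_type::ProofType;
use raiko_reqpool::{
    mock_redis_pool, AggregationRequestEntity, AggregationRequestKey, RequestEntity, RequestKey,
    Status, StatusWithContext,
};

fn main() -> Result<(), String> {
    let mut pool = mock_redis_pool("redis_pool_lifecycle_sketch");

    // Assumed proof type variant and Default impl; adjust to the real types.
    let key: RequestKey = AggregationRequestKey::new(ProofType::Native, vec![1, 2, 3]).into();
    let entity: RequestEntity = AggregationRequestEntity::new(
        vec![1, 2, 3], // aggregation ids
        vec![],        // no sub-proofs yet
        ProofType::Native,
        ProverSpecificOpts::default(),
    )
    .into();

    // Register the request, then advance its status.
    pool.add(key.clone(), entity, StatusWithContext::new_registered())?;
    let old = pool.update_status(key.clone(), Status::WorkInProgress.into())?;
    assert!(matches!(old.into_status(), Status::Registered));

    // Inspect and clean up.
    let status = pool.get_status(&key)?.expect("status should exist");
    println!("{key} is now {status}");
    pool.remove(&key)?;
    Ok(())
}
```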
304 changes: 304 additions & 0 deletions reqpool/src/request.rs
@@ -0,0 +1,304 @@
use crate::impl_display_using_json_pretty;
use alloy_primitives::Address;
use chrono::{DateTime, Utc};
use derive_getters::Getters;
use raiko_core::interfaces::ProverSpecificOpts;
use raiko_lib::{
input::BlobProofType,
primitives::{ChainId, B256},
proof_type::ProofType,
prover::Proof,
};
use raiko_redis_derive::RedisValue;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
use std::collections::HashMap;

#[derive(RedisValue, PartialEq, Debug, Clone, Deserialize, Serialize, Eq, PartialOrd, Ord)]
#[serde(rename_all = "snake_case")]
/// The status of a request
pub enum Status {
// === Normal status ===
/// The request is registered but not yet started
Registered,

/// The request is in progress
WorkInProgress,

// /// The request is in progress of proving
// WorkInProgressProving {
// /// The proof ID
// /// For SP1 and RISC0 proof type, it is the proof ID returned by the network prover,
// /// otherwise, it should be empty.
// proof_id: String,
// },
/// The request is successful
Success {
/// The proof of the request
proof: Proof,
},

// === Cancelled status ===
/// The request is cancelled
Cancelled,

// === Error status ===
/// The request is failed with an error
Failed {
/// The error message
error: String,
},
}

impl Status {
pub fn is_success(&self) -> bool {
matches!(self, Status::Success { .. })
}
}

#[derive(
PartialEq, Debug, Clone, Deserialize, Serialize, Eq, PartialOrd, Ord, RedisValue, Getters,
)]
/// The status of a request with context
pub struct StatusWithContext {
/// The status of the request
status: Status,
/// The timestamp of the status
timestamp: DateTime<Utc>,
}

impl StatusWithContext {
pub fn new(status: Status, timestamp: DateTime<Utc>) -> Self {
Self { status, timestamp }
}

pub fn new_registered() -> Self {
Self::new(Status::Registered, chrono::Utc::now())
}

pub fn new_cancelled() -> Self {
Self::new(Status::Cancelled, chrono::Utc::now())
}

pub fn into_status(self) -> Status {
self.status
}
}

impl From<Status> for StatusWithContext {
fn from(status: Status) -> Self {
Self::new(status, chrono::Utc::now())
}
}

/// The key to identify a request in the pool
#[derive(
PartialEq, Debug, Clone, Deserialize, Serialize, Eq, PartialOrd, Ord, Hash, RedisValue,
)]
pub enum RequestKey {
SingleProof(SingleProofRequestKey),
Aggregation(AggregationRequestKey),
}

impl RequestKey {
pub fn proof_type(&self) -> &ProofType {
match self {
RequestKey::SingleProof(key) => &key.proof_type,
RequestKey::Aggregation(key) => &key.proof_type,
}
}
}

/// The key to identify a request in the pool
#[derive(
PartialEq, Debug, Clone, Deserialize, Serialize, Eq, PartialOrd, Ord, Hash, RedisValue, Getters,
)]
pub struct SingleProofRequestKey {
/// The chain ID of the request
chain_id: ChainId,
/// The block number of the request
block_number: u64,
/// The block hash of the request
block_hash: B256,
/// The proof type of the request
proof_type: ProofType,
/// The prover of the request
prover_address: String,
}

impl SingleProofRequestKey {
pub fn new(
chain_id: ChainId,
block_number: u64,
block_hash: B256,
proof_type: ProofType,
prover_address: String,
) -> Self {
Self {
chain_id,
block_number,
block_hash,
proof_type,
prover_address,
}
}
}

#[derive(
PartialEq, Debug, Clone, Deserialize, Serialize, Eq, PartialOrd, Ord, Hash, RedisValue, Getters,
)]
/// The key to identify an aggregation request in the pool
pub struct AggregationRequestKey {
// TODO add chain_id
proof_type: ProofType,
block_numbers: Vec<u64>,
}

impl AggregationRequestKey {
pub fn new(proof_type: ProofType, block_numbers: Vec<u64>) -> Self {
Self {
proof_type,
block_numbers,
}
}
}

impl From<SingleProofRequestKey> for RequestKey {
fn from(key: SingleProofRequestKey) -> Self {
RequestKey::SingleProof(key)
}
}

impl From<AggregationRequestKey> for RequestKey {
fn from(key: AggregationRequestKey) -> Self {
RequestKey::Aggregation(key)
}
}

#[serde_as]
#[derive(PartialEq, Debug, Clone, Deserialize, Serialize, RedisValue, Getters)]
pub struct SingleProofRequestEntity {
/// The block number for the block to generate a proof for.
block_number: u64,
/// The L1 block number at which the L2 block was proposed.
l1_inclusion_block_number: u64,
/// The network to generate the proof for.
network: String,
/// The L1 network to generate the proof for.
l1_network: String,
/// Graffiti.
graffiti: B256,
/// The protocol instance data.
#[serde_as(as = "DisplayFromStr")]
prover: Address,
/// The proof type.
proof_type: ProofType,
/// Blob proof type.
blob_proof_type: BlobProofType,
#[serde(flatten)]
/// Additional prover params.
prover_args: HashMap<String, serde_json::Value>,
}

impl SingleProofRequestEntity {
pub fn new(
block_number: u64,
l1_inclusion_block_number: u64,
network: String,
l1_network: String,
graffiti: B256,
prover: Address,
proof_type: ProofType,
blob_proof_type: BlobProofType,
prover_args: HashMap<String, serde_json::Value>,
) -> Self {
Self {
block_number,
l1_inclusion_block_number,
network,
l1_network,
graffiti,
prover,
proof_type,
blob_proof_type,
prover_args,
}
}
}

#[derive(PartialEq, Debug, Clone, Deserialize, Serialize, RedisValue, Getters)]
pub struct AggregationRequestEntity {
/// The block numbers and l1 inclusion block numbers for the blocks to aggregate proofs for.
aggregation_ids: Vec<u64>,
/// The proofs of the blocks to aggregate.
proofs: Vec<Proof>,
/// The proof type.
proof_type: ProofType,
#[serde(flatten)]
/// Any additional prover params in JSON format.
prover_args: ProverSpecificOpts,
}

impl AggregationRequestEntity {
pub fn new(
aggregation_ids: Vec<u64>,
proofs: Vec<Proof>,
proof_type: ProofType,
prover_args: ProverSpecificOpts,
) -> Self {
Self {
aggregation_ids,
proofs,
proof_type,
prover_args,
}
}
}

/// The entity of a request
#[derive(PartialEq, Debug, Clone, Deserialize, Serialize, RedisValue)]
pub enum RequestEntity {
SingleProof(SingleProofRequestEntity),
Aggregation(AggregationRequestEntity),
}

impl From<SingleProofRequestEntity> for RequestEntity {
fn from(entity: SingleProofRequestEntity) -> Self {
RequestEntity::SingleProof(entity)
}
}

impl From<AggregationRequestEntity> for RequestEntity {
fn from(entity: AggregationRequestEntity) -> Self {
RequestEntity::Aggregation(entity)
}
}

// === impl Display using json_pretty ===

impl_display_using_json_pretty!(RequestKey);
impl_display_using_json_pretty!(SingleProofRequestKey);
impl_display_using_json_pretty!(AggregationRequestKey);
impl_display_using_json_pretty!(RequestEntity);
impl_display_using_json_pretty!(SingleProofRequestEntity);
impl_display_using_json_pretty!(AggregationRequestEntity);

// === impl Display for Status ===

impl std::fmt::Display for Status {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Status::Registered => write!(f, "Registered"),
Status::WorkInProgress => write!(f, "WorkInProgress"),
Status::Success { .. } => write!(f, "Success"),
Status::Cancelled => write!(f, "Cancelled"),
Status::Failed { error } => write!(f, "Failed({})", error),
}
}
}

impl std::fmt::Display for StatusWithContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.status())
}
}
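
A small sketch of how the Display implementations above differ: request keys render as pretty-printed JSON through the macro, while `Status` (and `StatusWithContext`, which delegates to it) renders as a short tag. The chain id is passed as a plain u64 on the assumption that `ChainId` is the usual alloy alias, `ProofType::Native` is again an assumed variant, and the block hash and prover address are placeholders.

```rust
use alloy_primitives::B256;
use raiko_lib::proof_type::ProofType;
use raiko_reqpool::{RequestKey, SingleProofRequestKey, Status, StatusWithContext};

fn main() {
    let key: RequestKey = SingleProofRequestKey::new(
        167_000,           // chain_id (assumed to be a u64 alias)
        123_456,           // block_number
        B256::ZERO,        // block_hash (placeholder)
        ProofType::Native, // assumed variant
        "0x0000000000000000000000000000000000000000".to_string(), // prover address (placeholder)
    )
    .into();

    // Keys pretty-print as JSON via impl_display_using_json_pretty!.
    println!("{key}");

    // Status prints as a short tag, and StatusWithContext delegates to it.
    let status: StatusWithContext = Status::Failed {
        error: "proof backend unreachable".to_string(),
    }
    .into();
    println!("{status}"); // => Failed(proof backend unreachable)
    assert!(!status.into_status().is_success());
}
```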
27 changes: 27 additions & 0 deletions reqpool/src/utils.rs
@@ -0,0 +1,27 @@
use raiko_lib::{proof_type::ProofType, prover::ProofKey};

use crate::{RequestKey, SingleProofRequestKey};

/// Returns the proof key corresponding to the request key.
///
/// During proving, the prover stores the network proof id in the pool, identified by a **proof key**. This
/// function generates a unique request key corresponding to that proof key, so that the proof id can be stored
/// in the pool.
///
/// Note that this is a hack, and it should be removed in the future.
pub fn proof_key_to_hack_request_key(proof_key: ProofKey) -> RequestKey {
let (chain_id, block_number, block_hash, proof_type) = proof_key;

// HACK: Use a special prover address as a mask, to distinguish from real
// RequestKeys
let hack_prover_address = String::from("0x1231231231231231231231231231231231231231");

SingleProofRequestKey::new(
chain_id,
block_number,
block_hash,
ProofType::try_from(proof_type).expect("unsupported proof type, it should not happen at proof_key_to_hack_request_key, please issue a bug report"),
hack_prover_address,
)
.into()
}
