From 80c674c4da83f178d9a67b3f6241628b80328ba7 Mon Sep 17 00:00:00 2001 From: amir Date: Thu, 5 Dec 2024 04:32:13 -0500 Subject: [PATCH] Add post_replies --- benches/streams_benches/user.rs | 20 ++++++++++++---- src/models/user/stream.rs | 42 ++++++++++++++++++++++++++++++++- src/routes/v0/stream/users.rs | 6 +++++ 3 files changed, 63 insertions(+), 5 deletions(-) diff --git a/benches/streams_benches/user.rs b/benches/streams_benches/user.rs index fa3d0e1e..a9e32854 100644 --- a/benches/streams_benches/user.rs +++ b/benches/streams_benches/user.rs @@ -24,6 +24,8 @@ pub fn bench_stream_following(c: &mut Criterion) { Some(20), UserStreamSource::Pioneers, None, + None, + None, ) .await .unwrap(); @@ -50,6 +52,8 @@ pub fn bench_stream_most_followed(c: &mut Criterion) { Some(20), UserStreamSource::MostFollowed, None, + None, + None, ) .await .unwrap(); @@ -94,10 +98,18 @@ pub fn bench_stream_pioneers(c: &mut Criterion) { c.bench_function("stream_pioneers", |b| { b.to_async(&rt).iter(|| async { - let user_stream = - UserStream::get_by_id(None, None, None, Some(20), UserStreamSource::Pioneers, None) - .await - .unwrap(); + let user_stream = UserStream::get_by_id( + None, + None, + None, + Some(20), + UserStreamSource::Pioneers, + None, + None, + None, + ) + .await + .unwrap(); criterion::black_box(user_stream); }); }); diff --git a/src/models/user/stream.rs b/src/models/user/stream.rs index 39d84002..d475396f 100644 --- a/src/models/user/stream.rs +++ b/src/models/user/stream.rs @@ -1,5 +1,8 @@ +use std::collections::HashSet; + use super::{Muted, UserCounts, UserSearch, UserView}; use crate::models::follow::{Followers, Following, Friends, UserFollows}; +use crate::models::post::{PostStream, POST_REPLIES_PER_POST_KEY_PARTS}; use crate::types::DynError; use crate::{db::kv::index::sorted_sets::SortOrder, RedisOps}; use crate::{get_neo4j_graph, queries}; @@ -23,6 +26,7 @@ pub enum UserStreamSource { MostFollowed, Pioneers, Recommended, + PostReplies, } #[derive(Serialize, Deserialize, ToSchema, Default)] @@ -37,9 +41,13 @@ impl UserStream { skip: Option, limit: Option, source: UserStreamSource, + author_id: Option, + post_id: Option, depth: Option, ) -> Result, DynError> { - let user_ids = Self::get_user_list_from_source(user_id, source, skip, limit).await?; + let user_ids = + Self::get_user_list_from_source(user_id, source, author_id, post_id, skip, limit) + .await?; match user_ids { Some(users) => Self::from_listed_user_ids(&users, viewer_id, depth).await, None => Ok(None), @@ -191,6 +199,8 @@ impl UserStream { pub async fn get_user_list_from_source( user_id: Option<&str>, source: UserStreamSource, + author_id: Option, + post_id: Option, skip: Option, limit: Option, ) -> Result>, DynError> { @@ -257,6 +267,36 @@ impl UserStream { ) .await? } + UserStreamSource::PostReplies => { + let post_id = post_id.unwrap(); + let author_id = author_id.unwrap(); + let key_parts = [ + &POST_REPLIES_PER_POST_KEY_PARTS[..], + &[author_id.as_str(), post_id.as_str()], + ] + .concat(); + let replies = PostStream::try_from_index_sorted_set( + &key_parts, + None, + None, + None, + None, + SortOrder::Descending, + None, + ) + .await?; + let unique_user_ids: HashSet = replies + .map(|replies| { + replies + .into_iter() + .map(|reply| reply.0.split(":").next().unwrap().to_string()) + .collect::>() + }) + .into_iter() + .flatten() + .collect(); + Some(unique_user_ids.into_iter().collect()) + } }; Ok(user_ids) } diff --git a/src/routes/v0/stream/users.rs b/src/routes/v0/stream/users.rs index 389aefe1..2c2831de 100644 --- a/src/routes/v0/stream/users.rs +++ b/src/routes/v0/stream/users.rs @@ -17,6 +17,8 @@ pub struct UserStreamQuery { skip: Option, limit: Option, source: Option, + author_id: Option, + post_id: Option, depth: Option, } @@ -31,6 +33,8 @@ pub struct UserStreamQuery { ("skip" = Option, Query, description = "Skip N followers"), ("limit" = Option, Query, description = "Retrieve N followers"), ("source" = Option, Query, description = "Source of users for the stream."), + ("author_id" = Option, Query, description = "Author id when source is 'post_replies'"), + ("post_id" = Option, Query, description = "Post id when source is 'post_replies'"), ("depth" = Option, Query, description = "User trusted network depth, user following users distance. Numbers bigger than 4, will be ignored") ), responses( @@ -92,6 +96,8 @@ pub async fn stream_users_handler( Some(skip), Some(limit), source.clone(), + query.author_id, + query.post_id, query.depth, ) .await