From 62b0e25ab815f9552dc9ccafb568dd0c9aacb870 Mon Sep 17 00:00:00 2001
From: Morgan Kain
Date: Fri, 19 Jan 2024 17:22:40 -0500
Subject: [PATCH] some modeling progress on the small bat dataset

---
 dataset_exploration/cluster_regression.stan | 122 +++++++++
 dataset_exploration/small_bat_dataset.R     | 278 ++++++++++++++++++--
 2 files changed, 382 insertions(+), 18 deletions(-)
 create mode 100644 dataset_exploration/cluster_regression.stan

diff --git a/dataset_exploration/cluster_regression.stan b/dataset_exploration/cluster_regression.stan
new file mode 100644
index 0000000..900e2e5
--- /dev/null
+++ b/dataset_exploration/cluster_regression.stan
@@ -0,0 +1,122 @@
+// Two-component normal mixture for y; the per-observation mixing weight is a logistic regression on mm plus a location effect.
+data {
+
+  int N;        // number of observations; dimension of y, plate, loc, spec and rows of mm
+  int N_plate;  // length of the plate-level effect vector indexed by plate[]
+  int N_loc;    // length of the location-level effect vectors indexed by loc[]
+  int N_cov;    // number of columns of mm / length of beta_diff
+  int N_spec;   // only used to size pop_sero in generated quantities
+
+  array[N] real y;     // response modelled by the mixture
+  array[N] int plate;  // index into mu_plate_dev
+  array[N] int loc;    // index into mu_loc_dev / beta_loc_dev and row index of pop_sero
+  array[N] int spec;   // column index of pop_sero
+
+  matrix[N, N_cov] mm; // covariate matrix for the mixing-weight regression
+
+}
+
+parameters {
+
+  real mu_base;     // baseline mean of component 1
+  real beta_base;   // intercept of the mixing-weight regression (logit scale)
+  real sigma_base;  // NOTE(review): used directly as a normal sd (sigma[1]) but declared without <lower=0>
+
+  vector[N_cov] beta_diff;  // covariate coefficients on the logit mixing weight
+  real mu_diff;     // offset of component 2's mean from component 1's; unconstrained in sign
+  real sigma_diff;  // offset of component 2's sd from component 1's; unconstrained in sign
+
+  real mu_plate_sd;  // NOTE(review): given an inv_gamma prior (positive support) but declared without <lower=0>
+  vector[N_plate] mu_plate_eps;
+
+  real mu_loc_sd;    // NOTE(review): same as mu_plate_sd, no <lower=0>
+  vector[N_loc] mu_loc_eps;
+
+  real beta_loc_sd;  // NOTE(review): same as mu_plate_sd, no <lower=0>
+  vector[N_loc] beta_loc_eps;
+
+}
+
+transformed parameters {
+
+  matrix[2, N] mu;      // row 1: component-1 mean, row 2: component-2 mean, per observation
+  vector[2] sigma;      // component standard deviations
+  vector[N] beta_vec;   // per-observation mixing weight in (0, 1)
+
+  vector[N_plate] mu_plate_dev;
+  vector[N_loc] mu_loc_dev;
+  vector[N_loc] beta_loc_dev;
+
+  sigma[1] = sigma_base;
+  sigma[2] = sigma_base + sigma_diff;  // NOTE(review): can be <= 0 because sigma_diff is unconstrained
+
+  for (n in 1:N_plate) {
+    mu_plate_dev[n] = mu_plate_sd * mu_plate_eps[n];  // scaled deviation: sd * raw effect
+  }
+
+  for (n in 1:N_loc) {
+    mu_loc_dev[n] = mu_loc_sd * mu_loc_eps[n];
+    beta_loc_dev[n] = beta_loc_sd * beta_loc_eps[n];
+  }
+
+  for (n in 1:N) {
+    mu[1, n] = mu_base + mu_plate_dev[plate[n]] + mu_loc_dev[loc[n]];
+    mu[2, n] = mu[1, n] + mu_diff;  // component 2 shares the plate and location effects
+  }
+
+  for (n in 1:N) {
+    beta_vec[n] = inv_logit(beta_base + mm[n, ] * beta_diff + beta_loc_dev[loc[n]]);  // NOTE(review): if mm carries an intercept column, it is redundant with beta_base -- confirm against the caller
+  }
+
+}
+
+model {
+
+// --- Priors --- //
+
+  mu_base ~ normal(0, 2);
+  mu_diff ~ normal(0, 2);
+
+  sigma_base ~ normal(0, 2);
+  sigma_diff ~ normal(0, 2);
+
+  beta_base ~ normal(0, 3);
+  beta_diff ~ normal(0, 2);
+
+  mu_plate_sd ~ inv_gamma(8, 15);
+  mu_loc_sd ~ inv_gamma(8, 15);
+  beta_loc_sd ~ inv_gamma(8, 15);
+
+  mu_plate_eps ~ normal(0, 3);  // NOTE(review): raw effects have scale 3, so the effective deviation scale is 3 * mu_plate_sd
+  mu_loc_eps ~ normal(0, 3);
+  beta_loc_eps ~ normal(0, 3);
+
+
+// --- Model --- //
+
+  for (n in 1:N)
+    target += log_mix(beta_vec[n],
+      normal_lpdf(y[n] | mu[2, n], sigma[2]),   // weighted by beta_vec[n]
+      normal_lpdf(y[n] | mu[1, n], sigma[1]));  // weighted by 1 - beta_vec[n]
+
+}
+
+generated quantities {
+
+  matrix[2, N] membership_l; matrix[2, N] membership_p; matrix[N_loc, N_spec] pop_sero = rep_matrix(0, N_loc, N_spec); array[N] int ind_sero; matrix[2, N] log_beta;  // pop_sero starts at 0 and accumulates ind_sero draws per (loc, spec)
+  vector[2] temp_beta_vec_1;
+  vector[2] temp_beta_vec_2;
+  real max_val = -500;  // floor applied to the log-scale component terms before exponentiating
+  for (n in 1:N) { log_beta[1, n] = log(beta_vec[n]); log_beta[2, n] = log(1 - beta_vec[n]); log_beta[1, n] += normal_lpdf(y[n] | mu[2, n], sigma[2]); log_beta[2, n] += normal_lpdf(y[n] | mu[1, n], sigma[1]);  // row 1: log(weight) + log-density under component 2; row 2: same for component 1
+
+    temp_beta_vec_1[1] = log_beta[1, n];
+    temp_beta_vec_1[2] = max_val;
+
+    temp_beta_vec_2[1] = log_beta[2, n];
+    temp_beta_vec_2[2] = max_val;
+    membership_l[1, n] = exp(max(temp_beta_vec_1)); membership_l[2, n] = exp(max(temp_beta_vec_2)); membership_p[1, n] = membership_l[1, n] / (membership_l[1, n] + membership_l[2, n]); membership_p[2, n] = membership_l[2, n] / (membership_l[1, n] + membership_l[2, n]); ind_sero[n] = binomial_rng(1, membership_p[1, n]);  // normalise the two floored terms to membership probabilities, then draw a 0/1 membership for component 2
+
+    pop_sero[loc[n], spec[n]] = pop_sero[loc[n], spec[n]] + ind_sero[n]; }
+
+}
+
diff --git a/dataset_exploration/small_bat_dataset.R b/dataset_exploration/small_bat_dataset.R
index d975c8c..e81e147 100644
--- a/dataset_exploration/small_bat_dataset.R
+++ b/dataset_exploration/small_bat_dataset.R
@@ -12,20 +12,20 @@ lapply(needed_packages, require, character.only = TRUE) %>% unlist()
 theme_set(theme_bw())
 suppressWarnings(
   theme_update(
-  axis.text.x = element_text(size = 10)
-  , axis.text.y = element_text(size = 10)
-  , axis.title.x = element_text(size = 16)
-  , axis.title.y = element_text(size = 16)
-  , legend.title = element_text(size = 12)
+    axis.text.x      = element_text(size = 10)
+  , axis.text.y      = element_text(size = 10)
+  , axis.title.x     = element_text(size = 16)
+  , axis.title.y     = element_text(size = 16)
+  , legend.title     = element_text(size = 12)
   , panel.grid.major = element_blank()
   , panel.grid.minor = element_blank()
   , strip.background = element_blank()
-  , panel.margin = unit(0, "lines")
-  , legend.key.size = unit(.55, "cm")
-  , legend.key = element_rect(fill = "white")
-  , panel.margin.y = unit(0.5, "lines")
-  , panel.border = element_rect(colour = "black", fill = NA, size = 1)
-  , strip.text.x = element_text(size = 16, colour = "black", face = "bold"))
+  , panel.margin     = unit(0, "lines")
+  , legend.key.size  = unit(.55, "cm")
+  , legend.key       = element_rect(fill = "white")
+  , panel.margin.y   = unit(0.5, "lines")
+  , panel.border     = element_rect(colour = "black", fill = NA, size = 1)
+  , strip.text.x     = element_text(size = 16, colour = "black", face = "bold"))
 )
 
 ## plotting function
@@ -102,12 +102,33 @@ mfi_plot <- function(
 }
 
 ## Load and some quick reorganization
-small_bat <- read.csv("/data/small_bat.csv", comment.char="#") %>% filter(!is.na(SLNO)) %>% dplyr::select(-SLNO) %>% pivot_longer(
-  .
-, -c("merge_identity", "date_sampled", "species", "location", "Sex", "Age", "BCS", "PlateNo")
-, names_to = "Virus"
-, values_to = "MFI"
-)
+small_bat <- read.csv("../data/small_bat.csv", comment.char="#") %>%  # path is relative to the working directory
+  filter(!is.na(SLNO)) %>%
+  dplyr::select(-SLNO) %>%
+  pivot_longer(
+    .
+  , -c("merge_identity", "date_sampled", "species", "location", "Sex", "Age", "BCS", "PlateNo")  # every other column becomes a Virus/MFI pair
+  , names_to = "Virus"
+  , values_to = "MFI"
+  ) %>%
+  mutate(
+    location = plyr::mapvalues(
+      location
+      ## Clean up some names
+    , from = c(
+        " Sadar,Faridpur", "Mogalhat,Lalmonirhat", "Sreemangal", " sadar, Khagrachari", "Gangachara, Rangpur"
+      , "Kakoni,Lalmonirhat", "Customs house,Lalmonirhat", " kaijuri, Rajbari", "Piljong,Bagerhat", "Abhaynagar, Jessore"
+      , "Rajbari", " kanaipur, Faripur")
+    , to = c(
+        "Sadar, Faridpur", "Mogalhat, Lalmonirhat", "Sreemangal", "Sadar, Khagrachari", "Gangachara, Rangpur"
+      , "Kakoni, Lalmonirhat", "Customs house, Lalmonirhat", "Kaijuri, Rajbari", "Piljong, Bagerhat", "Abhaynagar, Jessore"
+      , "Rajbari", "Kanaipur, Faripur")  # NOTE(review): "Faripur" is kept here while the first entry uses "Faridpur" -- confirm the intended spelling
+    )
+    ## Join redundant male sex labels
+  , Sex = plyr::mapvalues(Sex, from = c("male ", "Male "), to = c("Male", "Male"))
+  ) %>%
+  ## For consistency
+  rename(sex = Sex, age = Age, bcs = BCS)
 
 mfi_plot(
   viruses = unique(small_bat$Virus)
@@ -124,6 +145,227 @@ mfi_plot(
 , group_BY2 = "PlateNo"
 , colour_BY = "species"
 , facet_BY = NA
-, logMFI = TRUE
+, logMFI = FALSE
 , include.R = FALSE
 )
+
+#########################
+### ---- Fitting ---- ###
+
+## Potential covariates:
+ ## Location
+ ## Species
+ ## Sex
+ ## Age
+ ## BCS (body condition?)
+ ## Plate number
+
+small_bat.t <- small_bat %>%
+  ## For now, just filos
+  #filter(Virus %in% c("MarVGP", "EboVGP")) %>%
+  #filter(Virus %in% "MarVGP") %>%
+  filter(Virus %in% "EboVGP") %>%  # NOTE(review): only one virus is active at a time; the MarVGP alternative is commented out above
+  mutate(
+    location_n = as.numeric(as.factor(location))  # factor codes, passed to Stan as loc
+  , spec_n = as.numeric(as.factor(species))       # factor codes, passed to Stan as spec
+  ) %>%
+  mutate(
+    entry = seq(n()), .before = 1  # row id; later joined against the observation index of the Stan output
+  ) %>% mutate(
+    MFI = log10(MFI)  # MFI is on the log10 scale from here on
+  ) %>% droplevels()
+
+cov_mat <- model.matrix(~species+sex, small_bat.t)[, ]  # NOTE(review): this formula yields an "(Intercept)" column by default, and the Stan model adds its own beta_base -- confirm this is intended
+
+conflicted::conflicts_prefer(rstan::lookup)
+
+stan_fit <- stan(
+  file = "cluster_regression.stan"  # resolved relative to the working directory
+  , data = list(
+    N = nrow(small_bat.t)
+    , N_plate = max(small_bat.t$PlateNo)  # assumes PlateNo holds positive integer plate ids -- TODO confirm
+    , plate = small_bat.t$PlateNo
+    , N_loc = max(small_bat.t$location_n)
+    , loc = small_bat.t$location_n
+    , N_spec = max(small_bat.t$spec_n)
+    , spec = small_bat.t$spec_n
+    , y = small_bat.t$MFI
+    , mm = cov_mat
+    , N_cov = ncol(cov_mat)
+  )
+  , iter = 2000#6000
+  , warmup = 500#2000
+  , thin = 1
+  , chains = 3
+  , cores = 3
+  , seed = 10001
+  , refresh = 200
+  , control = list(adapt_delta = 0.92, max_treedepth = 13)
+)
+
+## Extract samples
+stan.fit.samples <- stan_fit %>% extract()
+
+membership_p.mar <- stan.fit.samples$membership_p  # NOTE(review): .mar and .ebo are both copies of the single fit above
+membership_p.ebo <- stan.fit.samples$membership_p
+
+small_bat.t.mar <- small_bat.t
+small_bat.t.ebo <- small_bat.t
+
+## Convert pop positive count to percentage
+#tot_n <- small_bat.t %>% group_by(location_n, spec_n) %>% summarize(n_entry = n()) %>% ungroup()
+#for (i in 1:nrow(tot_n)) {
+# stan.fit.samples$pop_sero[, tot_n$location_n[i], tot_n$spec_n[i]] <- stan.fit.samples$pop_sero[, tot_n$location_n[i], tot_n$spec_n[i]] / tot_n$n_entry[i]
+#}
+
+pop_sero.mar <- membership_p.mar %>%
+  reshape2::melt() %>%  # presumably yields columns iterations, Var2, Var3, value -- the names used below depend on this; verify
+  filter(Var2 == 1) %>%  # keep row 1 of membership_p
+  dplyr::select(-Var2) %>%
+  mutate(assign_pos = ifelse(value > 0.975, 1, 0)) %>%  # per-draw positive call at a 0.975 membership threshold
+  rename(entry = Var3) %>%
+  left_join(., small_bat.t, by = "entry") %>%
+  group_by(species, location, iterations) %>%
+  summarize(pos_perc = sum(assign_pos) / n()) %>%  # proportion called positive per species, location and draw
+  ungroup() %>%
+  group_by(species, location) %>%
+  summarize(
+    prob_inf_lwr = quantile(pos_perc, 0.025)
+  , prob_inf_lwr_n = quantile(pos_perc, 0.200)
+  , prob_inf_mid = quantile(pos_perc, 0.500)
+  , prob_inf_upr_n = quantile(pos_perc, 0.800)
+  , prob_inf_upr = quantile(pos_perc, 0.975)
+  ) %>% mutate(
+    Virus = "MarVGP"  # NOTE(review): hard-coded label; the active filter above selects "EboVGP"
+  # Virus = "EboVGP"
+  , .before = 1
+  )
+
+pop_sero.b <- rbind(
+  pop_sero.mar
+, pop_sero.ebo  # NOTE(review): pop_sero.ebo is not defined anywhere in this patch
+)
+
+prob_inf <- stan.fit.samples$membership_p %>%
+  reshape2::melt() %>%
+  filter(Var2 == 1) %>%
+  group_by(Var3) %>%
+  summarize(
+    prob_inf_lwr = quantile(value, 0.025)
+  , prob_inf_lwr_n = quantile(value, 0.200)
+  , prob_inf_mid = quantile(value, 0.500)
+  , prob_inf_upr_n = quantile(value, 0.800)
+  , prob_inf_upr = quantile(value, 0.975)
+  ) %>%
+  rename(entry = Var3)
+
+## Summarize stan estimates and add to data
+small_bat.t %<>%
+  left_join(., prob_inf, by = "entry") %>%
+  mutate(
+    assigned_positive = ifelse(prob_inf_lwr > 0.975, 1, 0) %>% as.factor()  # positive only when the 2.5% quantile of the membership probability exceeds 0.975
+  , .after = MFI
+  ) %>% dplyr::select(
+    -contains("prob")
+  )
+
+small_bat.t.MAR <- small_bat.t  # NOTE(review): .MAR and .EBO are copies of the same object here
+small_bat.t.EBO <- small_bat.t
+
+small_bat.t.b <- rbind(
+  small_bat.t.MAR
+, small_bat.t.EBO
+)
+
+small_bat.t.b %>% {
+  ggplot(., aes(MFI, assigned_positive)) + geom_jitter(aes(colour = Virus))
+}
+
+## Custom plot (can update sometime soonish to use function)
+scale_point_size_continuous <- function(name = ggplot2::waiver(), breaks = ggplot2::waiver(), labels = ggplot2::waiver(),
+                                        limits = NULL, range = c(1, 6),
+                                        trans = "identity", guide = "legend", aesthetics = "point_size") {
+  ggplot2::continuous_scale(aesthetics, "area", scales::area_pal(range), name = name,  # continuous scale for the "point_size" aesthetic used in the ridge plot below
+                            breaks = breaks, labels = labels, limits = limits, trans = trans,
+                            guide = guide)
+}
+scale_point_shape <- function(..., solid = TRUE, aesthetics = "point_shape") {
+  discrete_scale(aesthetics, "shape_d", scales::shape_pal(solid), ...)  # discrete scale for the "point_shape" aesthetic; not called in the code shown in this patch
+}
+
+small_bat.t.b[small_bat.t.b$entry == 350, ]$assigned_positive <- as.factor(c(0, 0))  # NOTE(review): hard-coded override for entry 350; the length-2 value assumes exactly two matching rows
+
+gg1 <- small_bat.t.b %>%
+  mutate(ps = as.numeric(assigned_positive)) %>%  # factor codes, mapped to point size below
+  rename(Seropostatus= assigned_positive) %>%
+  mutate(Seropostatus = plyr::mapvalues(
+    Seropostatus
+    , from = c(0, 1)
+    , to = c("Negative", "Positive")
+  )) %>% {
+  ggplot(., aes(
+      x = MFI
+    , y = location
+    , fill = species
+    , group = interaction(location, species)
+    , point_shape = Seropostatus
+    , point_size = ps
+    )
+  ) +
+    geom_density_ridges(
+      jittered_points = TRUE
+    , point_alpha = 0.6
+    , alpha = 0.5
+    # , point_size = 1.75
+    , size = .75
+    , quantile_lines = FALSE
+    , quantiles = 3
+    ) +
+    scale_fill_brewer(
+      palette = "Dark2"
+    , name = "Species"
+    ) +
+    scale_point_size_continuous(
+      range = c(1, 2)
+    , guide = "none"
+    ) +
+    ylab("Location") +
+    xlab("Log10(MFI)") +
+    theme(
+      axis.text.y = element_text(size = 10)
+    , axis.text.x = element_text(size = 10)
+    ) + facet_wrap(~Virus, scales = "free") +
+    scale_x_log10()  # NOTE(review): MFI was already log10-transformed when small_bat.t was built, so this logs the axis a second time -- confirm
+}
+
+gg1
+
+small_bat.t.b %>%
+  group_by(location, species) %>%
+  summarize(
+    tot_pos = (as.numeric(assigned_positive) - 1) %>% sum()  # factor code minus 1; presumably assumes levels "0" and "1" are both present -- verify
+  , tot_n = n()
+  ) %>% write.csv(
+    "positives.csv"
+  )
+
+pop_sero.b %>%
+  filter(prob_inf_lwr > 0) %>%  # keep groups whose lower interval bound is above zero
+  mutate(y_ax = interaction(species, location, sep = " -- ")) %>% {  # NOTE(review): y_ax is not used in the plot below, which maps location to y
+  ggplot(., aes(prob_inf_mid, location)) +
+    geom_errorbarh(aes(xmin = prob_inf_lwr, xmax = prob_inf_upr, colour = species)  # 2.5%-97.5% interval
+    , height = 0.3, linewidth = 0.75
+    , position = position_dodge(0.5)) +
+    geom_errorbarh(aes(xmin = prob_inf_lwr_n, xmax = prob_inf_upr_n, colour = species)  # 20%-80% interval
+    , height = 0, linewidth = 1.5
+    , position = position_dodge(0.5)) +
+    geom_point(aes(colour = species), size = 2
+    , position = position_dodge(0.5)) +
+    scale_colour_brewer(palette = "Dark2") +
+    xlab("Population Seropositivity") +
+    ylab("Location") +
+    theme(panel.spacing = unit(1, "lines")) +
+    facet_wrap(~Virus)
+}
+
+