Skip to content

Commit

Permalink
Fix optimise_random_prevalence()
Browse files Browse the repository at this point in the history
Now does something sensible for the case where correlation = NA (namely fixes the number of sampling periods to 1, and optimises only on the pooling strategy)
  • Loading branch information
AngusMcLure committed Feb 1, 2024
1 parent 151a021 commit 46b9b64
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 22 deletions.
2 changes: 1 addition & 1 deletion R/fisher_information.R
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ fi_pool_random <- function(catch_dist,
prevalence,
sensitivity,
specificity,
max_iter = 1000,
max_iter = 10000,
rel_tol = 1e-6){

#Calculates Fisher information (FI) for an unknown/random catch by taking
Expand Down
59 changes: 39 additions & 20 deletions R/optimise_prevalence.R
Original file line number Diff line number Diff line change
Expand Up @@ -283,21 +283,37 @@ optimise_random_prevalence <- function(catch_mean, catch_variance,
verbose = FALSE) {

strat_par_names <- names(formals(pool_strat_family))
pars <- as.list(rep(1, 1 + length(strat_par_names)))
names(pars) <- c(".periods", strat_par_names)

if(is.na(correlation)){
if(is.na(correlation)){ #If correlation is NA, assumes no correlation. However, sampling everything from the one site is optimal (i.e. periods = Inf). Consequently, we assume period = 1

pars <- as.list(rep(1, length(strat_par_names)))
names(pars) <- strat_par_names
catch <- nb_catch(catch_mean, catch_variance)
f <- function(prm){
cost_fi_random(nb_catch(catch_mean * prm$.periods, catch_variance * prm$.periods),
do.call(pool_strat_family, prm[strat_par_names]),
cost_fi_random(catch,
do.call(pool_strat_family, prm),
prevalence, sensitivity,specificity,
cost_unit, cost_pool, prm$.periods * cost_period)
cost_unit, cost_pool, cost_period)
}

opt <- optim_local_int(pars,f,verbose = verbose)
optpars <- optim_local_int(pars,f,verbose = verbose, max_iter = 100)


opt <- list(periods = NA,
cost = optpars$val,
catch = list(mean = catch_mean,
variance = catch_variance,
distribution = catch),
pool_strat = do.call(pool_strat_family, optpars$par),
pool_strat_pars = optpars$par)


}else{

pars <- as.list(rep(1, 1 + length(strat_par_names)))
names(pars) <- c(".periods", strat_par_names)

f <- function(prm){

cost_fi_cluster_random(nb_catch(catch_mean * prm$.periods, catch_variance * prm$.periods),
do.call(pool_strat_family, prm[strat_par_names]),
prevalence, correlation,
Expand All @@ -308,16 +324,18 @@ optimise_random_prevalence <- function(catch_mean, catch_variance,

optpars <- optim_local_int(pars,f, verbose = verbose)


opt <- list(periods = optpars$par$.periods,
cost = optpars$val,
catch = list(mean = optpars$par$.periods * catch_mean,
variance = optpars$par$.periods * catch_variance,
distribution = nb_catch(optpars$par$.periods * catch_mean,
optpars$par$.periods * catch_variance)),
pool_strat = do.call(pool_strat_family, optpars$par[-1]),
pool_strat_pars = optpars$par[-1])

}

opt <- list(periods = optpars$par$.periods,
cost = optpars$val,
catch = list(mean = optpars$par$.periods * catch_mean,
variance = optpars$par$.periods * catch_variance,
distribution = nb_catch(optpars$par$.periods * catch_mean,
optpars$par$.periods * catch_variance)),
pool_strat = do.call(pool_strat_family, optpars$par[-1]),
pool_strat_pars = optpars$par[-1])

opt
}

Expand All @@ -343,17 +361,18 @@ optim_local_int <- function(par, fn,
val <- c()
for(i in 1:nrow(search)){
if(verbose){print(paste0('Calculating for paramter set #',i,': ', paste(unlist(search[i,]), collapse = ", ") ))}
val[i] <- fn(as.list(search[i,]))
val[i] <- fn(as.list(search[i,,drop = FALSE]))
}

min_val <- min(val)
if(opt_val < min_val){
if(opt_val <= min_val){
break
}else{
opt_val <- min_val
opt_par <- unlist(search[which.min(val),])
opt_par <- unlist(search[which.min(val),,drop=FALSE])
}
}
if(iter == max_iter){warning('Local integer search reached max_iter and may have terminated early. Consider increasing max_iter')}
return(list(val = opt_val, par = as.list(opt_par)))
}

Expand Down
4 changes: 3 additions & 1 deletion R/util.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ mu_sigma_linknorm <- function(.mean, .var,link,invlink){


ev <- function(fn, distr,
max_iter = 1000,
max_iter = 10000,
rel_tol = 1e-6){

#Helper function for calculating the expected value of a function with respect
Expand Down Expand Up @@ -81,6 +81,8 @@ ev <- function(fn, distr,
terminate <- TRUE
}
if(iter == max_iter){
print(rel_incr)
print(iter)
terminate <- TRUE
warning('Reached max_iter without converging. Increase max_iter')
plot(xs, E_incr)
Expand Down

0 comments on commit 46b9b64

Please sign in to comment.