diff --git a/NAMESPACE b/NAMESPACE index 23deb4ea..db772722 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -7,6 +7,7 @@ export("%>%") export(accelerator) export(fit) export(luz_callback) +export(luz_callback_early_stopping) export(luz_callback_metrics) export(luz_callback_progress) export(luz_callback_train_valid) diff --git a/R/callbacks.R b/R/callbacks.R index 2bd33d24..e5cdb669 100644 --- a/R/callbacks.R +++ b/R/callbacks.R @@ -53,7 +53,7 @@ default_callbacks <- function() { #' A `luz_callback` that can be passed to [fit.luz_module_generator()]. #' @family luz_callbacks #' @export -luz_callback <- function(name, ..., private = NULL, active = NULL, parent_env = parent.frame(), +luz_callback <- function(name = NULL, ..., private = NULL, active = NULL, parent_env = parent.frame(), inherit = NULL) { make_class( name = name, @@ -263,4 +263,112 @@ luz_callback_train_valid <- luz_callback( } ) +#' Early stopping callback +#' +#' Stops training when a monitored metric stops improving +#' +#' @param monitor A string in the format `<set>_<metric>` where `<set>` can be +#' 'train' or 'valid' and `<metric>` can be the abbreviation of any metric +#' that you are tracking during training. +#' @param min_delta Minimum improvement to reset the patience counter. +#' @param patience Number of epochs without improving until stopping training. +#' @param mode Specifies the direction that is considered an improvement. By default +#' 'min' is used. Can also be 'max' (higher is better) and 'zero' +#' (closer to zero is better). +#' @param baseline An initial value that will be used as the best seen value +#' in the beginning. Model will stop training if no better than baseline value +#' is found in the first `patience` epochs. +#' +#' @note +#' This callback adds an `on_early_stopping` callback that can be used to +#' call callbacks as soon as the model stops training. +#' +#' @note +#' If `verbose=TRUE` in [fit.luz_module_generator()] a message is printed when +#' early stopping. 
+#' +#' @returns +#' A `luz_callback` that does early stopping. +#' +#' @examples +#' cb <- luz_callback_early_stopping() +#' +#' @family luz_callbacks +#' @export +luz_callback_early_stopping <- luz_callback( + name = "early_stopping_callback", + initialize = function(monitor = "valid_loss", min_delta = 0, patience = 0, + mode="min", baseline=NULL) { + self$monitor <- monitor + self$min_delta <- min_delta + self$patience <- patience + self$mode <- mode + self$baseline <- baseline + + if (!is.null(self$baseline)) + self$current_best <- baseline + + self$patience_counter <- 0L + }, + on_fit_begin = function() { + ctx$handlers <- append(ctx$handlers, list( + early_stopping = function(err) { + ctx$call_callbacks("on_early_stopping") + invisible(NULL) + } + )) + }, + on_epoch_end = function() { + + qty <- self$find_quantity() + if (is.null(self$current_best)) + self$current_best <- qty + + if (self$compare(qty, self$current_best)) { + # means that new qty is better than previous + self$current_best <- qty + self$patience_counter <- 0L + } else { + # means that qty did not improve + self$patience_counter <- self$patience_counter + 1L + } + + if (self$patience_counter >= self$patience) { + rlang::signal("Early stopping", class = "early_stopping") + } + + }, + on_early_stopping = function() { + inform(glue::glue("Early stopping at epoch {ctx$epoch} of {ctx$epochs}")) + }, + find_quantity = function() { + o <- strsplit(self$monitor, "_")[[1]] + set <- o[[1]] + qty <- o[[2]] + opt <- if (length(o) >= 3) o[[3]] else "opt" + + out <- if (qty == "loss") { + as.numeric(utils::tail(ctx$losses[[set]], 1)[[1]][[opt]]) + } else { + as.numeric(ctx$records$metrics[[set]][[qty]][[opt]]) + } + + if (length(out) != 1) + rlang::abort(glue::glue("Expected monitored metric to be length 1, got {length(out)}")) + + out + }, + # returns TRUE when the new is better than previous according to mode + compare = function(new, old) { + out <- if (self$mode == "min") + (old - self$min_delta) > new + 
else if (self$mode == "max") + (new - self$min_delta) > old + else if (self$mode == "zero") + (abs(old) - self$min_delta) > abs(new) + + as.array(out) + } +) + diff --git a/R/module.R b/R/module.R index 0635928d..ede169b3 100644 --- a/R/module.R +++ b/R/module.R @@ -193,48 +193,53 @@ fit.luz_module_generator <- function(object, data, epochs = 10, callbacks = NULL call_all_callbacks(ctx$callbacks, name) } - ctx$call_callbacks("on_fit_begin") + ctx$handlers <- list() - for (epoch in seq_len(ctx$epochs)) { - ctx$epoch <- epoch - ctx$iter <- 0L - ctx$call_callbacks("on_epoch_begin") + ctx$call_callbacks("on_fit_begin") + rlang::with_handlers( + !!! ctx$handlers, + .expr = { + for (epoch in seq_len(ctx$epochs)) { + ctx$epoch <- epoch + ctx$iter <- 0L + ctx$call_callbacks("on_epoch_begin") - ctx$call_callbacks("on_train_begin") + ctx$call_callbacks("on_train_begin") - coro::loop(for (batch in ctx$data) { - bind_batch_to_ctx(ctx, batch) - ctx$iter <- ctx$iter + 1L + coro::loop(for (batch in ctx$data) { + bind_batch_to_ctx(ctx, batch) + ctx$iter <- ctx$iter + 1L - ctx$call_callbacks("on_train_batch_begin") - step() - ctx$call_callbacks("on_train_batch_end") - }) + ctx$call_callbacks("on_train_batch_begin") + step() + ctx$call_callbacks("on_train_batch_end") + }) - ctx$call_callbacks("on_train_end") + ctx$call_callbacks("on_train_end") - if (!is.null(ctx$valid_data)) { + if (!is.null(ctx$valid_data)) { - ctx$call_callbacks("on_valid_begin") + ctx$call_callbacks("on_valid_begin") - ctx$iter <- 0L - torch::with_no_grad({ - coro::loop(for (batch in ctx$valid_data) { - bind_batch_to_ctx(ctx, batch) - ctx$iter <- ctx$iter + 1L + ctx$iter <- 0L + torch::with_no_grad({ + coro::loop(for (batch in ctx$valid_data) { + bind_batch_to_ctx(ctx, batch) + ctx$iter <- ctx$iter + 1L - ctx$call_callbacks("on_valid_batch_begin") - step() - ctx$call_callbacks("on_valid_batch_end") - }) - }) + ctx$call_callbacks("on_valid_batch_begin") + step() + 
ctx$call_callbacks("on_valid_batch_end") + }) + }) - ctx$call_callbacks("on_valid_end") + ctx$call_callbacks("on_valid_end") - } + } - ctx$call_callbacks("on_epoch_end") - } + ctx$call_callbacks("on_epoch_end") + } + }) ctx$call_callbacks("on_fit_end") structure( @@ -251,7 +256,7 @@ fit.luz_module_generator <- function(object, data, epochs = 10, callbacks = NULL #' @importFrom stats predict #' @export predict.luz_module_fitted <- function(object, newdata, ..., callbacks = list(), - accelerator = NULL) { + accelerator = NULL) { ctx <- object$ctx @@ -274,6 +279,7 @@ predict.luz_module_fitted <- function(object, newdata, ..., callbacks = list(), else stack <- pars$stack + ctx$handlers <- list() ctx$output <- list() ctx$callbacks <- initialize_callbacks(callbacks, ctx) @@ -281,13 +287,18 @@ predict.luz_module_fitted <- function(object, newdata, ..., callbacks = list(), torch::with_no_grad({ ctx$call_callbacks("on_predict_begin") - coro::loop(for(batch in data) { - ctx$batch <- batch - ctx$input <- batch[[1]] - ctx$call_callbacks("on_predict_batch_begin") - ctx$output[[length(ctx$output) + 1]] <- do.call(predict_fn, list(ctx$input)) - ctx$call_callbacks("on_predict_batch_end") - }) + rlang::with_handlers( + !!! 
ctx$handlers, + .expr = { + coro::loop(for(batch in data) { + ctx$batch <- batch + ctx$input <- batch[[1]] + ctx$call_callbacks("on_predict_batch_begin") + ctx$output[[length(ctx$output) + 1]] <- do.call(predict_fn, list(ctx$input)) + ctx$call_callbacks("on_predict_batch_end") + }) + } + ) ctx$call_callbacks("on_predict_end") }) diff --git a/R/utils.R b/R/utils.R index 3687e711..a86acdaf 100644 --- a/R/utils.R +++ b/R/utils.R @@ -125,3 +125,5 @@ make_class <- function(name, ..., private, active, inherit, parent_env, .init_fu attr(f, "r6_class") <- r6_class f } + + diff --git a/man/ctx.Rd b/man/ctx.Rd index e1dc2229..a0b15eae 100644 --- a/man/ctx.Rd +++ b/man/ctx.Rd @@ -36,6 +36,7 @@ could potentially modify these attributes or add new ones.\tabular{ll}{ \code{metrics} \tab \code{list()} of metric objects that are \code{update}d at every \code{on_train_batch_end()} or \code{on_valid_batch_end()}. See also \code{help("luz_callback_metrics")} \cr \code{records} \tab \code{list()} recording metric values for training and validation for each epoch. See also \code{help("luz_callback_metrics")} \cr \code{losses} \tab \code{list()} tracking losses over time. See also \code{help("luz_callback_metrics")} \cr + \code{handlers} \tab A named \code{list()} of handlers that is passed to \code{rlang::with_handlers()} during the training loop and can be used to handle errors or conditions that might be raised by other callbacks. 
\cr } diff --git a/man/luz_callback.Rd b/man/luz_callback.Rd index fe3e61b3..e1f66efb 100644 --- a/man/luz_callback.Rd +++ b/man/luz_callback.Rd @@ -134,6 +134,7 @@ print_callback <- luz_callback( } \seealso{ Other luz_callbacks: +\code{\link{luz_callback_early_stopping}()}, \code{\link{luz_callback_metrics}()}, \code{\link{luz_callback_progress}()}, \code{\link{luz_callback_train_valid}()} diff --git a/man/luz_callback_early_stopping.Rd b/man/luz_callback_early_stopping.Rd new file mode 100644 index 00000000..ff331065 --- /dev/null +++ b/man/luz_callback_early_stopping.Rd @@ -0,0 +1,56 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/callbacks.R +\name{luz_callback_early_stopping} +\alias{luz_callback_early_stopping} +\title{Early stopping callback} +\usage{ +luz_callback_early_stopping( + monitor = "valid_loss", + min_delta = 0, + patience = 0, + mode = "min", + baseline = NULL +) +} +\arguments{ +\item{monitor}{A string in the format \verb{<set>_<metric>} where \verb{<set>} can be +'train' or 'valid' and \verb{<metric>} can be the abbreviation of any metric +that you are tracking during training.} + +\item{min_delta}{Minimum improvement to reset the patience counter.} + +\item{patience}{Number of epochs without improving until stopping training.} + +\item{mode}{Specifies the direction that is considered an improvement. By default +'min' is used. Can also be 'max' (higher is better) and 'zero' +(closer to zero is better).} + +\item{baseline}{An initial value that will be used as the best seen value +in the beginning. Model will stop training if no better than baseline value +is found in the first \code{patience} epochs.} +} +\value{ +A \code{luz_callback} that does early stopping. +} +\description{ +Stops training when a monitored metric stops improving +} +\note{ +This callback adds an \code{on_early_stopping} callback that can be used to +call callbacks as soon as the model stops training. 
+ +If \code{verbose=TRUE} in \code{\link[=fit.luz_module_generator]{fit.luz_module_generator()}} a message is printed when +early stopping. +} +\examples{ +cb <- luz_callback_early_stopping() + +} +\seealso{ +Other luz_callbacks: +\code{\link{luz_callback_metrics}()}, +\code{\link{luz_callback_progress}()}, +\code{\link{luz_callback_train_valid}()}, +\code{\link{luz_callback}()} +} +\concept{luz_callbacks} diff --git a/man/luz_callback_metrics.Rd b/man/luz_callback_metrics.Rd index 4c6db3b8..f6956105 100644 --- a/man/luz_callback_metrics.Rd +++ b/man/luz_callback_metrics.Rd @@ -26,6 +26,7 @@ used by default in \code{\link[=fit.luz_module_generator]{fit.luz_module_generat } \seealso{ Other luz_callbacks: +\code{\link{luz_callback_early_stopping}()}, \code{\link{luz_callback_progress}()}, \code{\link{luz_callback_train_valid}()}, \code{\link{luz_callback}()} diff --git a/man/luz_callback_progress.Rd b/man/luz_callback_progress.Rd index d211d40a..190d8d3f 100644 --- a/man/luz_callback_progress.Rd +++ b/man/luz_callback_progress.Rd @@ -17,6 +17,7 @@ Printing can be disabled by passing \code{verbose=FALSE} to \code{\link[=fit.luz } \seealso{ Other luz_callbacks: +\code{\link{luz_callback_early_stopping}()}, \code{\link{luz_callback_metrics}()}, \code{\link{luz_callback_train_valid}()}, \code{\link{luz_callback}()} diff --git a/man/luz_callback_train_valid.Rd b/man/luz_callback_train_valid.Rd index 825ed977..201fad7b 100644 --- a/man/luz_callback_train_valid.Rd +++ b/man/luz_callback_train_valid.Rd @@ -26,6 +26,7 @@ used by default in \code{\link[=fit.luz_module_generator]{fit.luz_module_generat } \seealso{ Other luz_callbacks: +\code{\link{luz_callback_early_stopping}()}, \code{\link{luz_callback_metrics}()}, \code{\link{luz_callback_progress}()}, \code{\link{luz_callback}()} diff --git a/man/rmd/ctx.Rmd b/man/rmd/ctx.Rmd index e824f3d9..bc21b403 100644 --- a/man/rmd/ctx.Rmd +++ b/man/rmd/ctx.Rmd @@ -49,5 +49,7 @@ The `ctx` object is used in luz to share information 
between the training loop a +------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | `losses` | `list()` tracking losses over time. See also `help("luz_callback_metrics")` | +------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| `handlers` | A named `list()` of handlers that is passed to `rlang::with_handlers()` during the training loop and can be used to handle errors or conditions that might be raised by other callbacks. | ++------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ : Context attributes diff --git a/tests/testthat/_snaps/callbacks.md b/tests/testthat/_snaps/callbacks.md new file mode 100644 index 00000000..4de71b2f --- /dev/null +++ b/tests/testthat/_snaps/callbacks.md @@ -0,0 +1,32 @@ +# early stopping + + Code + expect_message({ + output <- mod %>% set_hparams(input_size = 10, output_size = 1) %>% fit(dl, + verbose = TRUE, epochs = 25, callbacks = list(luz_callback_early_stopping( + monitor = "train_loss", patience = 1))) + }) + Message + Train metrics: Loss: 1.5301 + Early stopping at epoch 1 of 25 + +--- + + Code + expect_message({ + output <- mod %>% set_hparams(input_size = 10, output_size = 1) %>% fit(dl, + verbose = TRUE, epochs = 25, callbacks = list(luz_callback_early_stopping( + monitor = "train_loss", patience = 5, baseline = 0.001))) + }) + Message + Train metrics: Loss: 1.4807 + Epoch 2/25 + Train metrics: Loss: 1.3641 + Epoch 3/25 + Train metrics: Loss: 1.2073 + Epoch 4/25 + Train metrics: Loss: 1.2524 + Epoch 5/25 + Train metrics: Loss: 1.1891 + 
Early stopping at epoch 5 of 25 + diff --git a/tests/testthat/test-callbacks.R b/tests/testthat/test-callbacks.R new file mode 100644 index 00000000..fbf0b409 --- /dev/null +++ b/tests/testthat/test-callbacks.R @@ -0,0 +1,48 @@ +test_that("early stopping", { + torch::torch_manual_seed(1) + set.seed(1) + + model <- get_model() + dl <- get_dl() + + mod <- model %>% + setup( + loss = torch::nn_mse_loss(), + optimizer = optim_adam, + ) + + expect_snapshot({ + expect_message({ + output <- mod %>% + set_hparams(input_size = 10, output_size = 1) %>% + fit(dl, verbose = TRUE, epochs = 25, callbacks = list( + luz_callback_early_stopping(monitor = "train_loss", patience = 1) + )) + }) + }) + + expect_snapshot({ + expect_message({ + output <- mod %>% + set_hparams(input_size = 10, output_size = 1) %>% + fit(dl, verbose = TRUE, epochs = 25, callbacks = list( + luz_callback_early_stopping(monitor = "train_loss", patience = 5, + baseline = 0.001) + )) + }) + }) + + x <- 0 + output <- mod %>% + set_hparams(input_size = 10, output_size = 1) %>% + fit(dl, verbose = FALSE, epochs = 25, callbacks = list( + luz_callback_early_stopping(monitor = "train_loss", patience = 5, + baseline = 0.001), + luz_callback(on_early_stopping = function() { + x <<- 1 + })() + )) + + expect_equal(x, 1) +}) +