From 12e35657b87489cda31d85a9ff78b6e40be1e077 Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Tue, 11 May 2021 17:54:28 -0300 Subject: [PATCH 1/7] start adding support for early stopping --- R/callbacks.R | 76 ++++++++++++++++++++++++++++ R/module.R | 87 +++++++++++++++++++-------------- R/utils.R | 2 + tests/testthat/test-callbacks.R | 17 +++++++ 4 files changed, 144 insertions(+), 38 deletions(-) create mode 100644 tests/testthat/test-callbacks.R diff --git a/R/callbacks.R b/R/callbacks.R index 2bd33d24..93e45d5e 100644 --- a/R/callbacks.R +++ b/R/callbacks.R @@ -263,4 +263,80 @@ luz_callback_train_valid <- luz_callback( } ) +luz_callback_early_stopping <- luz_callback( + name = "early_stopping_callback", + initialize = function(monitor = "valid_loss", min_delta = 0, patience = 0, + verbose = FALSE, mode="min", baseline=NULL) { + self$monitor <- monitor + self$min_delta <- 0 + self$patience <- 0 + self$verbose <- verbose + self$mode <- mode + self$baseline <- baseline + + if (!is.null(self$baseline)) + private$current_best <- baseline + }, + on_fit_begin = function() { + ctx$handlers <- append(ctx$handlers, list( + early_stopping = function(err) { + ctx$call_all_callbacks("on_early_stopping") + invisible(NULL) + } + )) + }, + on_epoch_end = function() { + + qty <- self$find_quantity() + if (is.null(self$current_best)) + self$current_best <- qty + + if (self$compare(qty, self$current_best)) { + # means that new qty is better then previous + self$current_best <- qty + self$patience_counter <- 0L + } else { + # mean that qty did not improve + self$patience_counter <- self$patience_counter + 1L + } + + if (self$patience_counter > self$patience) { + rlang::signal("Early stopping", class = "early_stopping") + } + + }, + on_early_stopping = function() { + inform(glue::glue("Early stopping at epoch {ctx$epoch} of {ctx$epochs}")) + }, + find_quantity = function() { + o <- strsplit(self$monitor, "_")[[1]] + set <- o[[1]] + qty <- o[[2]] + opt <- if (length(o) >= 3) o[[3]] else "opt" + + out <- if (qty == "loss") { + as.numeric(utils::tail(ctx$losses[[set]], 1)[[1]][[opt]]) + } else { + as.numeric(ctx$records$metrics[[set]][[qty]][[opt]]) + } + + browser() + + if (length(out) != 1) + rlang::abort(glue::glue("Expected monitored metric to be length 1, got {length(out)}")) + + out + }, + compare = function(x, y) { + out <- if (self$mode == "min") + x < y + else if (self$mode == "max") + x > y + else if (self$mode == "zero") + abs(x) < abs(y) + + as.numeric(out) + } +) + diff --git a/R/module.R b/R/module.R index 0635928d..ede169b3 100644 --- a/R/module.R +++ b/R/module.R @@ -193,48 +193,53 @@ fit.luz_module_generator <- function(object, data, epochs = 10, callbacks = NULL call_all_callbacks(ctx$callbacks, name) } - ctx$call_callbacks("on_fit_begin") + ctx$handlers <- list() - for (epoch in seq_len(ctx$epochs)) { - ctx$epoch <- epoch - ctx$iter <- 0L - ctx$call_callbacks("on_epoch_begin") + ctx$call_callbacks("on_fit_begin") + rlang::with_handlers( + !!! ctx$handlers, + .expr = { + for (epoch in seq_len(ctx$epochs)) { + ctx$epoch <- epoch + ctx$iter <- 0L + ctx$call_callbacks("on_epoch_begin") - ctx$call_callbacks("on_train_begin") + ctx$call_callbacks("on_train_begin") - coro::loop(for (batch in ctx$data) { - bind_batch_to_ctx(ctx, batch) - ctx$iter <- ctx$iter + 1L + coro::loop(for (batch in ctx$data) { + bind_batch_to_ctx(ctx, batch) + ctx$iter <- ctx$iter + 1L - ctx$call_callbacks("on_train_batch_begin") - step() - ctx$call_callbacks("on_train_batch_end") - }) + ctx$call_callbacks("on_train_batch_begin") + step() + ctx$call_callbacks("on_train_batch_end") + }) - ctx$call_callbacks("on_train_end") + ctx$call_callbacks("on_train_end") - if (!is.null(ctx$valid_data)) { + if (!is.null(ctx$valid_data)) { - ctx$call_callbacks("on_valid_begin") + ctx$call_callbacks("on_valid_begin") - ctx$iter <- 0L - torch::with_no_grad({ - coro::loop(for (batch in ctx$valid_data) { - bind_batch_to_ctx(ctx, batch) - ctx$iter <- ctx$iter + 1L + ctx$iter <- 0L + torch::with_no_grad({ + coro::loop(for (batch in ctx$valid_data) { + bind_batch_to_ctx(ctx, batch) + ctx$iter <- ctx$iter + 1L - ctx$call_callbacks("on_valid_batch_begin") - step() - ctx$call_callbacks("on_valid_batch_end") - }) - }) + ctx$call_callbacks("on_valid_batch_begin") + step() + ctx$call_callbacks("on_valid_batch_end") + }) + }) - ctx$call_callbacks("on_valid_end") + ctx$call_callbacks("on_valid_end") - } + } - ctx$call_callbacks("on_epoch_end") - } + ctx$call_callbacks("on_epoch_end") + } + }) ctx$call_callbacks("on_fit_end") structure( @@ -251,7 +256,7 @@ fit.luz_module_generator <- function(object, data, epochs = 10, callbacks = NULL #' @importFrom stats predict #' @export predict.luz_module_fitted <- function(object, newdata, ..., callbacks = list(), - accelerator = NULL) { + accelerator = NULL) { ctx <- object$ctx @@ -274,6 +279,7 @@ predict.luz_module_fitted <- function(object, newdata, ..., callbacks = list(), else stack <- pars$stack + ctx$handlers <- list() ctx$output <- list() ctx$callbacks <- initialize_callbacks(callbacks, ctx) @@ -281,13 +287,18 @@ predict.luz_module_fitted <- function(object, newdata, ..., callbacks = list(), torch::with_no_grad({ ctx$call_callbacks("on_predict_begin") - coro::loop(for(batch in data) { - ctx$batch <- batch - ctx$input <- batch[[1]] - ctx$call_callbacks("on_predict_batch_begin") - ctx$output[[length(ctx$output) + 1]] <- do.call(predict_fn, list(ctx$input)) - ctx$call_callbacks("on_predict_batch_end") - }) + rlang::with_handlers( + !!! ctx$handlers, + .expr = { + coro::loop(for(batch in data) { + ctx$batch <- batch + ctx$input <- batch[[1]] + ctx$call_callbacks("on_predict_batch_begin") + ctx$output[[length(ctx$output) + 1]] <- do.call(predict_fn, list(ctx$input)) + ctx$call_callbacks("on_predict_batch_end") + }) + } + ) ctx$call_callbacks("on_predict_end") }) diff --git a/R/utils.R b/R/utils.R index 3687e711..a86acdaf 100644 --- a/R/utils.R +++ b/R/utils.R @@ -125,3 +125,5 @@ make_class <- function(name, ..., private, active, inherit, parent_env, .init_fu attr(f, "r6_class") <- r6_class f } + + diff --git a/tests/testthat/test-callbacks.R b/tests/testthat/test-callbacks.R new file mode 100644 index 00000000..d23a0e91 --- /dev/null +++ b/tests/testthat/test-callbacks.R @@ -0,0 +1,17 @@ +test_that("early stopping", { + model <- get_model() + dl <- get_dl() + + mod <- model %>% + setup( + loss = torch::nn_mse_loss(), + optimizer = optim_adam, + ) + + output <- mod %>% + set_hparams(input_size = 10, output_size = 1) %>% + fit(dl, verbose = TRUE, callbacks = list( + luz_callback_early_stopping(monitor = "train_loss", patience = 1) + )) + +}) From 6aa57d5a73e60f5a239022f4db03455efc18ee1d Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Tue, 11 May 2021 20:03:22 -0300 Subject: [PATCH 2/7] add early stopping callback --- NAMESPACE | 1 + R/callbacks.R | 40 +++++++++++++++++++---- man/ctx.Rd | 1 + man/luz_callback.Rd | 1 + man/luz_callback_early_stopping.Rd | 52 ++++++++++++++++++++++++++++++ man/luz_callback_metrics.Rd | 1 + man/luz_callback_progress.Rd | 1 + man/luz_callback_train_valid.Rd | 1 + man/rmd/ctx.Rmd | 2 ++ tests/testthat/_snaps/callbacks.md | 17 ++++++++++ tests/testthat/test-callbacks.R | 17 +++++++--- 11 files changed, 122 insertions(+), 12 deletions(-) create mode 100644 man/luz_callback_early_stopping.Rd create mode 100644 tests/testthat/_snaps/callbacks.md diff --git a/NAMESPACE b/NAMESPACE index 23deb4ea..db772722 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -7,6 +7,7 @@ export("%>%") export(accelerator) export(fit) export(luz_callback) +export(luz_callback_early_stopping) export(luz_callback_metrics) export(luz_callback_progress) export(luz_callback_train_valid) diff --git a/R/callbacks.R b/R/callbacks.R index 93e45d5e..1e3d739e 100644 --- a/R/callbacks.R +++ b/R/callbacks.R @@ -263,24 +263,51 @@ luz_callback_train_valid <- luz_callback( } ) +#' Early stopping callback +#' +#' Stops training when a monitored metric stops improving +#' +#' @param monitor A string in the format `_` where `` can be +#' 'train' or 'valid' and `` can be the abbreviation of any metric +#' that you are tracking during training. +#' @param min_delta Minimum improvement to reset the patience counter. +#' @param patience Number of epochs without improving until stoping training. +#' @param verbose If `TRUE` a message will be printed when early stopping occurs. +#' @param mode Specifies the direction that is considered an improvement. By default +#' 'min' is used. Can also be 'max' (higher is better) and 'zero' +#' (closer to zero is better). +#' @param baseline An initial value that will be used as the best seen value +#' in the begining. Model will stopm training if no better than baseline value +#' is found in the first `patience` epochs. +#' +#' @returns +#' A `luz_callback` that does early stopping. +#' +#' @examples +#' cb <- luz_callback_early_stopping() +#' +#' @family luz_callbacks +#' @export luz_callback_early_stopping <- luz_callback( name = "early_stopping_callback", initialize = function(monitor = "valid_loss", min_delta = 0, patience = 0, verbose = FALSE, mode="min", baseline=NULL) { self$monitor <- monitor - self$min_delta <- 0 - self$patience <- 0 + self$min_delta <- min_delta + self$patience <- patience self$verbose <- verbose self$mode <- mode self$baseline <- baseline if (!is.null(self$baseline)) private$current_best <- baseline + + self$patience_counter <- 0L }, on_fit_begin = function() { ctx$handlers <- append(ctx$handlers, list( early_stopping = function(err) { - ctx$call_all_callbacks("on_early_stopping") + ctx$call_callbacks("on_early_stopping") invisible(NULL) } )) @@ -306,7 +333,8 @@ luz_callback_early_stopping <- luz_callback( }, on_early_stopping = function() { - inform(glue::glue("Early stopping at epoch {ctx$epoch} of {ctx$epochs}")) + if (self$verbose) + rlang::inform(glue::glue("Early stopping at epoch {ctx$epoch} of {ctx$epochs}")) }, find_quantity = function() { o <- strsplit(self$monitor, "_")[[1]] @@ -320,8 +348,6 @@ luz_callback_early_stopping <- luz_callback( as.numeric(ctx$records$metrics[[set]][[qty]][[opt]]) } - browser() - if (length(out) != 1) rlang::abort(glue::glue("Expected monitored metric to be length 1, got {length(out)}")) @@ -335,7 +361,7 @@ luz_callback_early_stopping <- luz_callback( else if (self$mode == "zero") abs(x) < abs(y) - as.numeric(out) + as.array(out) } ) diff --git a/man/ctx.Rd b/man/ctx.Rd index e1dc2229..a0b15eae 100644 --- a/man/ctx.Rd +++ b/man/ctx.Rd @@ -36,6 +36,7 @@ could potentially modify these attributes or add new ones.\tabular{ll}{ \code{metrics} \tab \code{list()} of metric objects that are \code{update}d at every \code{on_train_batch_end()} or \code{on_valid_batch_end()}. See also \code{help("luz_callback_metrics")} \cr \code{records} \tab \code{list()} recording metric values for training and validation for each epoch. See also \code{help("luz_callback_metrics")} \cr \code{losses} \tab \code{list()} tracking losses over time. See also \code{help("luz_callback_metrics")} \cr + \code{handlers} \tab A named \code{list()} of handlers that is passed to \code{rlang::with_handlers()} during the training loop and can be used to handle errors or conditions that might be raised by other callbacks. \cr } diff --git a/man/luz_callback.Rd b/man/luz_callback.Rd index fe3e61b3..e1f66efb 100644 --- a/man/luz_callback.Rd +++ b/man/luz_callback.Rd @@ -134,6 +134,7 @@ print_callback <- luz_callback( } \seealso{ Other luz_callbacks: +\code{\link{luz_callback_early_stopping}()}, \code{\link{luz_callback_metrics}()}, \code{\link{luz_callback_progress}()}, \code{\link{luz_callback_train_valid}()} diff --git a/man/luz_callback_early_stopping.Rd b/man/luz_callback_early_stopping.Rd new file mode 100644 index 00000000..1c2caf70 --- /dev/null +++ b/man/luz_callback_early_stopping.Rd @@ -0,0 +1,52 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/callbacks.R +\name{luz_callback_early_stopping} +\alias{luz_callback_early_stopping} +\title{Early stopping callback} +\usage{ +luz_callback_early_stopping( + monitor = "valid_loss", + min_delta = 0, + patience = 0, + verbose = FALSE, + mode = "min", + baseline = NULL +) +} +\arguments{ +\item{monitor}{A string in the format \verb{_} where \verb{} can be +'train' or 'valid' and \verb{} can be the abbreviation of any metric +that you are tracking during training.} + +\item{min_delta}{Minimum improvement to reset the patience counter.} + +\item{patience}{Number of epochs without improving until stoping training.} + +\item{verbose}{If \code{TRUE} a message will be printed when early stopping occurs.} + +\item{mode}{Specifies the direction that is considered an improvement. By default +'min' is used. Can also be 'max' (higher is better) and 'zero' +(closer to zero is better).} + +\item{baseline}{An initial value that will be used as the best seen value +in the begining. Model will stopm training if no better than baseline value +is found in the first \code{patience} epochs.} +} +\value{ +A \code{luz_callback} that does early stopping. +} +\description{ +Stops training when a monitored metric stops improving +} +\examples{ +cb <- luz_callback_early_stopping() + +} +\seealso{ +Other luz_callbacks: +\code{\link{luz_callback_metrics}()}, +\code{\link{luz_callback_progress}()}, +\code{\link{luz_callback_train_valid}()}, +\code{\link{luz_callback}()} +} +\concept{luz_callbacks} diff --git a/man/luz_callback_metrics.Rd b/man/luz_callback_metrics.Rd index 4c6db3b8..f6956105 100644 --- a/man/luz_callback_metrics.Rd +++ b/man/luz_callback_metrics.Rd @@ -26,6 +26,7 @@ used by default in \code{\link[=fit.luz_module_generator]{fit.luz_module_generat } \seealso{ Other luz_callbacks: +\code{\link{luz_callback_early_stopping}()}, \code{\link{luz_callback_progress}()}, \code{\link{luz_callback_train_valid}()}, \code{\link{luz_callback}()} diff --git a/man/luz_callback_progress.Rd b/man/luz_callback_progress.Rd index d211d40a..190d8d3f 100644 --- a/man/luz_callback_progress.Rd +++ b/man/luz_callback_progress.Rd @@ -17,6 +17,7 @@ Printing can be disabled by passing \code{verbose=FALSE} to \code{\link[=fit.luz } \seealso{ Other luz_callbacks: +\code{\link{luz_callback_early_stopping}()}, \code{\link{luz_callback_metrics}()}, \code{\link{luz_callback_train_valid}()}, \code{\link{luz_callback}()} diff --git a/man/luz_callback_train_valid.Rd b/man/luz_callback_train_valid.Rd index 825ed977..201fad7b 100644 --- a/man/luz_callback_train_valid.Rd +++ b/man/luz_callback_train_valid.Rd @@ -26,6 +26,7 @@ used by default in \code{\link[=fit.luz_module_generator]{fit.luz_module_generat } \seealso{ Other luz_callbacks: +\code{\link{luz_callback_early_stopping}()}, \code{\link{luz_callback_metrics}()}, \code{\link{luz_callback_progress}()}, \code{\link{luz_callback}()} diff --git a/man/rmd/ctx.Rmd b/man/rmd/ctx.Rmd index e824f3d9..bc21b403 100644 --- a/man/rmd/ctx.Rmd +++ b/man/rmd/ctx.Rmd @@ -49,5 +49,7 @@ The `ctx` object is used in luz to share information between the training loop a +------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | `losses` | `list()` tracking losses over time. See also `help("luz_callback_metrics")` | +------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| `handlers` | A named `list()` of handlers that is passed to `rlang::with_handlers()` during the training loop and can be used to handle errors or conditions that might be raised by other callbacks. | ++------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ : Context attributes diff --git a/tests/testthat/_snaps/callbacks.md b/tests/testthat/_snaps/callbacks.md new file mode 100644 index 00000000..f6972586 --- /dev/null +++ b/tests/testthat/_snaps/callbacks.md @@ -0,0 +1,17 @@ +# early stopping + + Code + expect_message({ + output <- mod %>% set_hparams(input_size = 10, output_size = 1) %>% fit(dl, + verbose = TRUE, epochs = 25, callbacks = list(luz_callback_early_stopping( + monitor = "train_loss", patience = 1))) + }) + Message + Train metrics: Loss: 1.5301 + Message + Epoch 2/25 + Message + Train metrics: Loss: 1.654 + Message + Early stopping at epoch 2 of 25 + diff --git a/tests/testthat/test-callbacks.R b/tests/testthat/test-callbacks.R index d23a0e91..fa2f31fe 100644 --- a/tests/testthat/test-callbacks.R +++ b/tests/testthat/test-callbacks.R @@ -1,4 +1,7 @@ test_that("early stopping", { + torch::torch_manual_seed(1) + set.seed(1) + model <- get_model() dl <- get_dl() @@ -8,10 +11,14 @@ test_that("early stopping", { optimizer = optim_adam, ) - output <- mod %>% - set_hparams(input_size = 10, output_size = 1) %>% - fit(dl, verbose = TRUE, callbacks = list( - luz_callback_early_stopping(monitor = "train_loss", patience = 1) - )) + expect_snapshot({ + expect_message({ + output <- mod %>% + set_hparams(input_size = 10, output_size = 1) %>% + fit(dl, verbose = TRUE, epochs = 25, callbacks = list( + luz_callback_early_stopping(monitor = "train_loss", patience = 1) + )) + }) + }) }) From f82d5fbbf8b0d64525ebde2840d256c9721c13d3 Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Tue, 11 May 2021 20:04:54 -0300 Subject: [PATCH 3/7] add note --- R/callbacks.R | 4 ++++ man/luz_callback_early_stopping.Rd | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/R/callbacks.R b/R/callbacks.R index 1e3d739e..b60fdf10 100644 --- a/R/callbacks.R +++ b/R/callbacks.R @@ -280,6 +280,10 @@ luz_callback_train_valid <- luz_callback( #' in the begining. Model will stopm training if no better than baseline value #' is found in the first `patience` epochs. #' +#' @note +#' This callback adds a `on_early_stopping` callback that can be used to +#' call callbacks after as soon as the model stopped training. +#' #' @returns #' A `luz_callback` that does early stopping. #' diff --git a/man/luz_callback_early_stopping.Rd b/man/luz_callback_early_stopping.Rd index 1c2caf70..75e6d591 100644 --- a/man/luz_callback_early_stopping.Rd +++ b/man/luz_callback_early_stopping.Rd @@ -38,6 +38,10 @@ A \code{luz_callback} that does early stopping. \description{ Stops training when a monitored metric stops improving } +\note{ +This callback adds a \code{on_early_stopping} callback that can be used to +call callbacks after as soon as the model stopped training. +} \examples{ cb <- luz_callback_early_stopping() From 2d99a1e9413238d919745e3caec868ab3808787f Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Tue, 11 May 2021 21:42:26 -0300 Subject: [PATCH 4/7] remove `verbose` arg in favor of the global verbose passed to `fit`. --- R/callbacks.R | 11 ++++++----- man/luz_callback_early_stopping.Rd | 6 +++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/R/callbacks.R b/R/callbacks.R index b60fdf10..43cb9ff8 100644 --- a/R/callbacks.R +++ b/R/callbacks.R @@ -272,7 +272,6 @@ luz_callback_train_valid <- luz_callback( #' that you are tracking during training. #' @param min_delta Minimum improvement to reset the patience counter. #' @param patience Number of epochs without improving until stoping training. -#' @param verbose If `TRUE` a message will be printed when early stopping occurs. #' @param mode Specifies the direction that is considered an improvement. By default #' 'min' is used. Can also be 'max' (higher is better) and 'zero' #' (closer to zero is better). @@ -284,6 +283,10 @@ luz_callback_train_valid <- luz_callback( #' This callback adds a `on_early_stopping` callback that can be used to #' call callbacks after as soon as the model stopped training. #' +#' @note +#' If `verbose=TRUE` in [fit.luz_module_generator()] a message is printed when +#' early stopping. +#' #' @returns #' A `luz_callback` that does early stopping. #' @@ -295,11 +298,10 @@ luz_callback_train_valid <- luz_callback( luz_callback_early_stopping <- luz_callback( name = "early_stopping_callback", initialize = function(monitor = "valid_loss", min_delta = 0, patience = 0, - verbose = FALSE, mode="min", baseline=NULL) { + mode="min", baseline=NULL) { self$monitor <- monitor self$min_delta <- min_delta self$patience <- patience - self$verbose <- verbose self$mode <- mode self$baseline <- baseline @@ -337,8 +339,7 @@ luz_callback_early_stopping <- luz_callback( }, on_early_stopping = function() { - if (self$verbose) - rlang::inform(glue::glue("Early stopping at epoch {ctx$epoch} of {ctx$epochs}")) + inform(glue::glue("Early stopping at epoch {ctx$epoch} of {ctx$epochs}")) }, find_quantity = function() { o <- strsplit(self$monitor, "_")[[1]] diff --git a/man/luz_callback_early_stopping.Rd b/man/luz_callback_early_stopping.Rd index 75e6d591..ff331065 100644 --- a/man/luz_callback_early_stopping.Rd +++ b/man/luz_callback_early_stopping.Rd @@ -8,7 +8,6 @@ luz_callback_early_stopping( monitor = "valid_loss", min_delta = 0, patience = 0, - verbose = FALSE, mode = "min", baseline = NULL ) @@ -22,8 +21,6 @@ that you are tracking during training.} \item{patience}{Number of epochs without improving until stoping training.} -\item{verbose}{If \code{TRUE} a message will be printed when early stopping occurs.} - \item{mode}{Specifies the direction that is considered an improvement. By default 'min' is used. Can also be 'max' (higher is better) and 'zero' (closer to zero is better).} @@ -41,6 +38,9 @@ Stops training when a monitored metric stops improving \note{ This callback adds a \code{on_early_stopping} callback that can be used to call callbacks after as soon as the model stopped training. + +If \code{verbose=TRUE} in \code{\link[=fit.luz_module_generator]{fit.luz_module_generator()}} a message is printed when +early stopping. } \examples{ cb <- luz_callback_early_stopping() From 01fdf9b7df95a41c86f6e5d0f98384b6a9242099 Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Tue, 11 May 2021 22:01:52 -0300 Subject: [PATCH 5/7] add a few more tests --- R/callbacks.R | 15 ++++++++------- tests/testthat/_snaps/callbacks.md | 29 +++++++++++++++++++++++++++-- tests/testthat/test-callbacks.R | 24 ++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 9 deletions(-) diff --git a/R/callbacks.R b/R/callbacks.R index 43cb9ff8..e5cdb669 100644 --- a/R/callbacks.R +++ b/R/callbacks.R @@ -53,7 +53,7 @@ default_callbacks <- function() { #' A `luz_callback` that can be passed to [fit.luz_module_generator()]. #' @family luz_callbacks #' @export -luz_callback <- function(name, ..., private = NULL, active = NULL, parent_env = parent.frame(), +luz_callback <- function(name = NULL, ..., private = NULL, active = NULL, parent_env = parent.frame(), inherit = NULL) { make_class( name = name, @@ -306,7 +306,7 @@ luz_callback_early_stopping <- luz_callback( self$baseline <- baseline if (!is.null(self$baseline)) - private$current_best <- baseline + self$current_best <- baseline self$patience_counter <- 0L }, @@ -333,7 +333,7 @@ luz_callback_early_stopping <- luz_callback( self$patience_counter <- self$patience_counter + 1L } - if (self$patience_counter > self$patience) { + if (self$patience_counter >= self$patience) { rlang::signal("Early stopping", class = "early_stopping") } @@ -358,13 +358,14 @@ luz_callback_early_stopping <- luz_callback( out }, - compare = function(x, y) { + # returns TRUE when the new is better then previous acording to mode + compare = function(new, old) { out <- if (self$mode == "min") - x < y + (old - self$min_delta) > new else if (self$mode == "max") - x > y + (new - self$min_delta) > old else if (self$mode == "zero") - abs(x) < abs(y) + (abs(old) - self$min_delta) > abs(self$min_delta) as.array(out) } diff --git a/tests/testthat/_snaps/callbacks.md b/tests/testthat/_snaps/callbacks.md index f6972586..387b0903 100644 --- a/tests/testthat/_snaps/callbacks.md +++ b/tests/testthat/_snaps/callbacks.md @@ -8,10 +8,35 @@ }) Message Train metrics: Loss: 1.5301 + Message + Early stopping at epoch 1 of 25 + +--- + + Code + expect_message({ + output <- mod %>% set_hparams(input_size = 10, output_size = 1) %>% fit(dl, + verbose = TRUE, epochs = 25, callbacks = list(luz_callback_early_stopping( + monitor = "train_loss", patience = 5, baseline = 0.001))) + }) + Message + Train metrics: Loss: 1.4807 Message Epoch 2/25 Message - Train metrics: Loss: 1.654 + Train metrics: Loss: 1.3641 + Message + Epoch 3/25 + Message + Train metrics: Loss: 1.2073 + Message + Epoch 4/25 + Message + Train metrics: Loss: 1.2524 + Message + Epoch 5/25 + Message + Train metrics: Loss: 1.1891 Message - Early stopping at epoch 2 of 25 + Early stopping at epoch 5 of 25 diff --git a/tests/testthat/test-callbacks.R b/tests/testthat/test-callbacks.R index fa2f31fe..fbf0b409 100644 --- a/tests/testthat/test-callbacks.R +++ b/tests/testthat/test-callbacks.R @@ -21,4 +21,28 @@ test_that("early stopping", { }) }) + expect_snapshot({ + expect_message({ + output <- mod %>% + set_hparams(input_size = 10, output_size = 1) %>% + fit(dl, verbose = TRUE, epochs = 25, callbacks = list( + luz_callback_early_stopping(monitor = "train_loss", patience = 5, + baseline = 0.001) + )) + }) + }) + + x <- 0 + output <- mod %>% + set_hparams(input_size = 10, output_size = 1) %>% + fit(dl, verbose = FALSE, epochs = 25, callbacks = list( + luz_callback_early_stopping(monitor = "train_loss", patience = 5, + baseline = 0.001), + luz_callback(on_early_stopping = function() { + x <<- 1 + })() + )) + + expect_equal(x, 1) }) + From c86620c4dc3dc4ad9be9c3ef49590850d59a9488 Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Tue, 11 May 2021 22:11:25 -0300 Subject: [PATCH 6/7] accept for new version of testthat --- tests/testthat/_snaps/callbacks.new.md | 32 ++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 tests/testthat/_snaps/callbacks.new.md diff --git a/tests/testthat/_snaps/callbacks.new.md b/tests/testthat/_snaps/callbacks.new.md new file mode 100644 index 00000000..4de71b2f --- /dev/null +++ b/tests/testthat/_snaps/callbacks.new.md @@ -0,0 +1,32 @@ +# early stopping + + Code + expect_message({ + output <- mod %>% set_hparams(input_size = 10, output_size = 1) %>% fit(dl, + verbose = TRUE, epochs = 25, callbacks = list(luz_callback_early_stopping( + monitor = "train_loss", patience = 1))) + }) + Message + Train metrics: Loss: 1.5301 + Early stopping at epoch 1 of 25 + +--- + + Code + expect_message({ + output <- mod %>% set_hparams(input_size = 10, output_size = 1) %>% fit(dl, + verbose = TRUE, epochs = 25, callbacks = list(luz_callback_early_stopping( + monitor = "train_loss", patience = 5, baseline = 0.001))) + }) + Message + Train metrics: Loss: 1.4807 + Epoch 2/25 + Train metrics: Loss: 1.3641 + Epoch 3/25 + Train metrics: Loss: 1.2073 + Epoch 4/25 + Train metrics: Loss: 1.2524 + Epoch 5/25 + Train metrics: Loss: 1.1891 + Early stopping at epoch 5 of 25 + From 516396b3d05a7a4ad4ad379fa6dea6f04c8a2f0d Mon Sep 17 00:00:00 2001 From: Daniel Falbel Date: Tue, 11 May 2021 22:16:50 -0300 Subject: [PATCH 7/7] call accept correctly :( --- tests/testthat/_snaps/callbacks.md | 10 -------- tests/testthat/_snaps/callbacks.new.md | 32 -------------------------- 2 files changed, 42 deletions(-) delete mode 100644 tests/testthat/_snaps/callbacks.new.md diff --git a/tests/testthat/_snaps/callbacks.md b/tests/testthat/_snaps/callbacks.md index 387b0903..4de71b2f 100644 --- a/tests/testthat/_snaps/callbacks.md +++ b/tests/testthat/_snaps/callbacks.md @@ -8,7 +8,6 @@ }) Message Train metrics: Loss: 1.5301 - Message Early stopping at epoch 1 of 25 --- @@ -21,22 +20,13 @@ }) Message Train metrics: Loss: 1.4807 - Message Epoch 2/25 - Message Train metrics: Loss: 1.3641 - Message Epoch 3/25 - Message Train metrics: Loss: 1.2073 - Message Epoch 4/25 - Message Train metrics: Loss: 1.2524 - Message Epoch 5/25 - Message Train metrics: Loss: 1.1891 - Message Early stopping at epoch 5 of 25 diff --git a/tests/testthat/_snaps/callbacks.new.md b/tests/testthat/_snaps/callbacks.new.md deleted file mode 100644 index 4de71b2f..00000000 --- a/tests/testthat/_snaps/callbacks.new.md +++ /dev/null @@ -1,32 +0,0 @@ -# early stopping - - Code - expect_message({ - output <- mod %>% set_hparams(input_size = 10, output_size = 1) %>% fit(dl, - verbose = TRUE, epochs = 25, callbacks = list(luz_callback_early_stopping( - monitor = "train_loss", patience = 1))) - }) - Message - Train metrics: Loss: 1.5301 - Early stopping at epoch 1 of 25 - ---- - - Code - expect_message({ - output <- mod %>% set_hparams(input_size = 10, output_size = 1) %>% fit(dl, - verbose = TRUE, epochs = 25, callbacks = list(luz_callback_early_stopping( - monitor = "train_loss", patience = 5, baseline = 0.001))) - }) - Message - Train metrics: Loss: 1.4807 - Epoch 2/25 - Train metrics: Loss: 1.3641 - Epoch 3/25 - Train metrics: Loss: 1.2073 - Epoch 4/25 - Train metrics: Loss: 1.2524 - Epoch 5/25 - Train metrics: Loss: 1.1891 - Early stopping at epoch 5 of 25 -