Skip to content

Commit

Permalink
Merge pull request #7 from mlverse/feature/early-stopping
Browse files Browse the repository at this point in the history
Feature/early stopping
  • Loading branch information
dfalbel authored May 12, 2021
2 parents bf744cb + 516396b commit 35d0d24
Show file tree
Hide file tree
Showing 13 changed files with 304 additions and 39 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export("%>%")
export(accelerator)
export(fit)
export(luz_callback)
export(luz_callback_early_stopping)
export(luz_callback_metrics)
export(luz_callback_progress)
export(luz_callback_train_valid)
Expand Down
110 changes: 109 additions & 1 deletion R/callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ default_callbacks <- function() {
#' A `luz_callback` that can be passed to [fit.luz_module_generator()].
#' @family luz_callbacks
#' @export
luz_callback <- function(name, ..., private = NULL, active = NULL, parent_env = parent.frame(),
luz_callback <- function(name = NULL, ..., private = NULL, active = NULL, parent_env = parent.frame(),
inherit = NULL) {
make_class(
name = name,
Expand Down Expand Up @@ -263,4 +263,112 @@ luz_callback_train_valid <- luz_callback(
}
)

#' Early stopping callback
#'
#' Stops training when a monitored metric stops improving
#'
#' @param monitor A string in the format `<set>_<metric>` where `<set>` can be
#'  'train' or 'valid' and `<metric>` can be the abbreviation of any metric
#'  that you are tracking during training.
#' @param min_delta Minimum improvement to reset the patience counter.
#' @param patience Number of epochs without improving until stopping training.
#'  When no `baseline` is given, the first epoch only records the initial
#'  value of the monitored metric and is not counted as an epoch without
#'  improvement.
#' @param mode Specifies the direction that is considered an improvement. By default
#'  'min' is used. Can also be 'max' (higher is better) and 'zero'
#'  (closer to zero is better).
#' @param baseline An initial value that will be used as the best seen value
#'  in the beginning. Model will stop training if no better than baseline value
#'  is found in the first `patience` epochs.
#'
#' @note
#' This callback adds a `on_early_stopping` callback that can be used to
#' call callbacks as soon as the model stopped training.
#'
#' @note
#' If `verbose=TRUE` in [fit.luz_module_generator()] a message is printed when
#' early stopping.
#'
#' @returns
#' A `luz_callback` that does early stopping.
#'
#' @examples
#' cb <- luz_callback_early_stopping()
#'
#' @family luz_callbacks
#' @export
luz_callback_early_stopping <- luz_callback(
  name = "early_stopping_callback",
  initialize = function(monitor = "valid_loss", min_delta = 0, patience = 0,
                        mode = "min", baseline = NULL) {
    # Fail early on an unsupported mode instead of erroring later, with an
    # obscure message, the first time `compare()` is called.
    valid_modes <- c("min", "max", "zero")
    if (!is.character(mode) || length(mode) != 1L || !mode %in% valid_modes) {
      rlang::abort("`mode` must be one of 'min', 'max' or 'zero'.")
    }

    self$monitor <- monitor
    self$min_delta <- min_delta
    self$patience <- patience
    self$mode <- mode
    self$baseline <- baseline

    # A user supplied baseline acts as the best value seen so far.
    if (!is.null(self$baseline)) {
      self$current_best <- baseline
    }

    self$patience_counter <- 0L
  },
  on_fit_begin = function() {
    # Register a handler for the `early_stopping` condition signalled from
    # `on_epoch_end()`. It calls the `on_early_stopping` callbacks.
    ctx$handlers <- append(ctx$handlers, list(
      early_stopping = function(err) {
        ctx$call_callbacks("on_early_stopping")
        invisible(NULL)
      }
    ))
  },
  on_epoch_end = function() {
    qty <- self$find_quantity()

    # Without a baseline there is nothing to compare the first value against:
    # record it as the current best and wait for the next epoch. Comparing the
    # value against itself would always count as "no improvement".
    if (is.null(self$current_best)) {
      self$current_best <- qty
      return(invisible(NULL))
    }

    if (self$compare(qty, self$current_best)) {
      # The new value is better than the best seen so far.
      self$current_best <- qty
      self$patience_counter <- 0L
    } else {
      # No improvement in this epoch.
      self$patience_counter <- self$patience_counter + 1L

      # Only an epoch without improvement can trigger early stopping;
      # otherwise `patience = 0` would stop even after an improving epoch.
      if (self$patience_counter >= self$patience) {
        rlang::signal("Early stopping", class = "early_stopping")
      }
    }

    invisible(NULL)
  },
  on_early_stopping = function() {
    inform(glue::glue("Early stopping at epoch {ctx$epoch} of {ctx$epochs}"))
  },
  # Extracts the current value of the monitored quantity from `ctx`.
  # `monitor` is split on "_": the first piece selects the set, the second the
  # quantity and an optional third piece the entry to read (defaults to "opt").
  find_quantity = function() {
    o <- strsplit(self$monitor, "_")[[1]]
    if (length(o) < 2) {
      rlang::abort(glue::glue(
        "`monitor` must be in the format '<set>_<metric>', got '{self$monitor}'."
      ))
    }

    set <- o[[1]]
    qty <- o[[2]]
    opt <- if (length(o) >= 3) o[[3]] else "opt"

    out <- if (qty == "loss") {
      # Losses are read from the most recent entry of `ctx$losses[[set]]`.
      as.numeric(utils::tail(ctx$losses[[set]], 1)[[1]][[opt]])
    } else {
      as.numeric(ctx$records$metrics[[set]][[qty]][[opt]])
    }

    if (length(out) != 1) {
      rlang::abort(glue::glue("Expected monitored metric to be length 1, got {length(out)}"))
    }

    out
  },
  # Returns TRUE when `new` is better than `old` according to `mode`, by a
  # margin larger than `min_delta`.
  compare = function(new, old) {
    out <- switch(
      self$mode,
      min = (old - self$min_delta) > new,
      max = (new - self$min_delta) > old,
      # Closer to zero is better: compare magnitudes of old and new values.
      zero = (abs(old) - self$min_delta) > abs(new),
      rlang::abort("`mode` must be one of 'min', 'max' or 'zero'.")
    )

    as.array(out)
  }
)


87 changes: 49 additions & 38 deletions R/module.R
Original file line number Diff line number Diff line change
Expand Up @@ -193,48 +193,53 @@ fit.luz_module_generator <- function(object, data, epochs = 10, callbacks = NULL
call_all_callbacks(ctx$callbacks, name)
}

ctx$call_callbacks("on_fit_begin")
ctx$handlers <- list()

for (epoch in seq_len(ctx$epochs)) {
ctx$epoch <- epoch
ctx$iter <- 0L
ctx$call_callbacks("on_epoch_begin")
ctx$call_callbacks("on_fit_begin")
rlang::with_handlers(
!!! ctx$handlers,
.expr = {
for (epoch in seq_len(ctx$epochs)) {
ctx$epoch <- epoch
ctx$iter <- 0L
ctx$call_callbacks("on_epoch_begin")

ctx$call_callbacks("on_train_begin")
ctx$call_callbacks("on_train_begin")

coro::loop(for (batch in ctx$data) {
bind_batch_to_ctx(ctx, batch)
ctx$iter <- ctx$iter + 1L
coro::loop(for (batch in ctx$data) {
bind_batch_to_ctx(ctx, batch)
ctx$iter <- ctx$iter + 1L

ctx$call_callbacks("on_train_batch_begin")
step()
ctx$call_callbacks("on_train_batch_end")
})
ctx$call_callbacks("on_train_batch_begin")
step()
ctx$call_callbacks("on_train_batch_end")
})

ctx$call_callbacks("on_train_end")
ctx$call_callbacks("on_train_end")

if (!is.null(ctx$valid_data)) {
if (!is.null(ctx$valid_data)) {

ctx$call_callbacks("on_valid_begin")
ctx$call_callbacks("on_valid_begin")

ctx$iter <- 0L
torch::with_no_grad({
coro::loop(for (batch in ctx$valid_data) {
bind_batch_to_ctx(ctx, batch)
ctx$iter <- ctx$iter + 1L
ctx$iter <- 0L
torch::with_no_grad({
coro::loop(for (batch in ctx$valid_data) {
bind_batch_to_ctx(ctx, batch)
ctx$iter <- ctx$iter + 1L

ctx$call_callbacks("on_valid_batch_begin")
step()
ctx$call_callbacks("on_valid_batch_end")
})
})
ctx$call_callbacks("on_valid_batch_begin")
step()
ctx$call_callbacks("on_valid_batch_end")
})
})

ctx$call_callbacks("on_valid_end")
ctx$call_callbacks("on_valid_end")

}
}

ctx$call_callbacks("on_epoch_end")
}
ctx$call_callbacks("on_epoch_end")
}
})

ctx$call_callbacks("on_fit_end")
structure(
Expand All @@ -251,7 +256,7 @@ fit.luz_module_generator <- function(object, data, epochs = 10, callbacks = NULL
#' @importFrom stats predict
#' @export
predict.luz_module_fitted <- function(object, newdata, ..., callbacks = list(),
accelerator = NULL) {
accelerator = NULL) {

ctx <- object$ctx

Expand All @@ -274,20 +279,26 @@ predict.luz_module_fitted <- function(object, newdata, ..., callbacks = list(),
else
stack <- pars$stack

ctx$handlers <- list()
ctx$output <- list()
ctx$callbacks <- initialize_callbacks(callbacks, ctx)

predict_fn <- if (is.null(ctx$model$predict)) ctx$model else ctx$model$predict

torch::with_no_grad({
ctx$call_callbacks("on_predict_begin")
coro::loop(for(batch in data) {
ctx$batch <- batch
ctx$input <- batch[[1]]
ctx$call_callbacks("on_predict_batch_begin")
ctx$output[[length(ctx$output) + 1]] <- do.call(predict_fn, list(ctx$input))
ctx$call_callbacks("on_predict_batch_end")
})
rlang::with_handlers(
!!! ctx$handlers,
.expr = {
coro::loop(for(batch in data) {
ctx$batch <- batch
ctx$input <- batch[[1]]
ctx$call_callbacks("on_predict_batch_begin")
ctx$output[[length(ctx$output) + 1]] <- do.call(predict_fn, list(ctx$input))
ctx$call_callbacks("on_predict_batch_end")
})
}
)
ctx$call_callbacks("on_predict_end")
})

Expand Down
2 changes: 2 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,5 @@ make_class <- function(name, ..., private, active, inherit, parent_env, .init_fu
attr(f, "r6_class") <- r6_class
f
}


1 change: 1 addition & 0 deletions man/ctx.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/luz_callback.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 56 additions & 0 deletions man/luz_callback_early_stopping.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/luz_callback_metrics.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/luz_callback_progress.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/luz_callback_train_valid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/rmd/ctx.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,7 @@ The `ctx` object is used in luz to share information between the training loop a
+------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| `losses` | `list()` tracking losses over time. See also `help("luz_callback_metrics")` |
+------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| `handlers` | A named `list()` of handlers that is passed to `rlang::with_handlers()` during the training loop and can be used to handle errors or conditions that might be raised by other callbacks. |
+------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

: Context attributes
Loading

0 comments on commit 35d0d24

Please sign in to comment.