Skip to content

Commit

Permalink
Merge pull request #55 from mattwarkentin/master
Browse files Browse the repository at this point in the history
Provide support for minimum and maximum number of epochs
  • Loading branch information
dfalbel authored Jul 29, 2021
2 parents 2932183 + 25eda79 commit 6e0bb77
Showing 6 changed files with 72 additions and 43 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -3,6 +3,8 @@
* Fixed bug in CSV logger callback that was saving the logs as a space delimited file (#52, @mattwarkentin).
* Fixed bug in the length of the progress bar for the validation dataset (#52, @mattwarkentin).
* `ctx$data` now refers to the current in use `data` instead of always refering to `ctx$train_data`. (#54)
* Allow users to provide the minimum and maximum number of epochs when calling `fit.luz_module_generator()`. Removed `ctx$epochs` from context object and replaced it with `ctx$min_epochs` and `ctx$max_epochs` (#53, @mattwarkentin)
* Early stopping will now only occur if the minimum number of training epochs has been met (#53, @mattwarkentin)

# luz 0.1.0

11 changes: 7 additions & 4 deletions R/callbacks.R
Original file line number Diff line number Diff line change
@@ -97,7 +97,7 @@ luz_callback_progress <- luz_callback(
inform(sprintf(
"Epoch %d/%d",
as.integer(ctx$epoch),
as.integer(ctx$epochs)
as.integer(ctx$max_epochs)
))
},
on_train_begin = function() {
@@ -356,7 +356,7 @@ monitor_metrics <- luz_callback(
#'
#' @note
#' This callback adds a `on_early_stopping` callback that can be used to
#' call callbacks after as soon as the model stopped training.
#' call callbacks as soon as the model stops training.
#'
#' @note
#' If `verbose=TRUE` in [fit.luz_module_generator()] a message is printed when
@@ -409,13 +409,16 @@ luz_callback_early_stopping <- luz_callback(
self$patience_counter <- self$patience_counter + 1L
}

if (self$patience_counter >= self$patience) {
if (self$patience_counter >= self$patience &
ctx$epoch >= ctx$min_epochs) {
rlang::signal("Early stopping", class = "early_stopping")
}

},
on_early_stopping = function() {
inform(glue::glue("Early stopping at epoch {ctx$epoch} of {ctx$epochs}"))
inform(
glue::glue("Early stopping at epoch {ctx$epoch} of {ctx$max_epochs}")
)
}
)

58 changes: 37 additions & 21 deletions R/module.R
Original file line number Diff line number Diff line change
@@ -117,27 +117,33 @@ get_opt_hparams <- function(module) {
#' @param object An `nn_module` that has been [setup()].
#'
#' @param data (dataloader) A dataloader created with [torch::dataloader()] used
#' for training the model. The dataloader must return a list with at most 2 items.
#' The first item will be used as input for the module and the second will be used
#' as target for the loss function.
#' for training the model. The dataloader must return a list with at most 2
#' items. The first item will be used as input for the module and the second
#' will be used as target for the loss function.
#'
#' @param epochs (int) The number of epochs for training the model.
#' @param epochs (int) The maximum number of epochs for training the model.
#' If a single value is provided, this is taken to be the `max_epochs` and
#' `min_epochs` is set to 0. If a vector of two numbers is provided, the
#' first value is `min_epochs` and the second value is `max_epochs`.
#' The minimum and maximum number of epochs are included in the context
#' object as `ctx$min_epochs` and `ctx$max_epochs`, respectively.
#'
#' @param callbacks (list, optional) A list of callbacks defined with [luz_callback()] that
#' will be called during the training procedure. The callbacks [luz_callback_metrics()],
#' [luz_callback_progress()] and [luz_callback_train_valid()] are always added by default.
#' @param callbacks (list, optional) A list of callbacks defined with
#' [luz_callback()] that will be called during the training procedure. The
#' callbacks [luz_callback_metrics()], [luz_callback_progress()] and
#' [luz_callback_train_valid()] are always added by default.
#'
#' @param valid_data (dataloader, optional) A dataloader created with [torch::dataloader()]
#' that will be used during the validation procedure.
#' @param valid_data (dataloader, optional) A dataloader created with
#' [torch::dataloader()] that will be used during the validation procedure.
#'
#' @param accelerator (accelerator, optional) An optional [accelerator()] object used
#' to configure device placement of the components like [nn_module]s, optimizers
#' and batches of data.
#' @param accelerator (accelerator, optional) An optional [accelerator()] object
#' used to configure device placement of the components like [nn_module]s,
#' optimizers and batches of data.
#'
#' @param verbose (logical, optional) An optional boolean value indicating if the
#' fitting procedure should emmit output to the console during training. By default,
#' it will produce output if [interactive()] is `TRUE`, otherwise it won't print
#' to the console.
#' @param verbose (logical, optional) An optional boolean value indicating if
#' the fitting procedure should emmit output to the console during training.
#' By default, it will produce output if [interactive()] is `TRUE`, otherwise
#' it won't print to the console.
#'
#' @param ... Currently unused,
#'
@@ -147,9 +153,16 @@ get_opt_hparams <- function(module) {
#'
#' @importFrom generics fit
#' @export
fit.luz_module_generator <- function(object, data, epochs = 10, callbacks = NULL,
valid_data = NULL, accelerator = NULL,
verbose = NULL, ...) {
fit.luz_module_generator <- function(
object,
data,
epochs = 10,
callbacks = NULL,
valid_data = NULL,
accelerator = NULL,
verbose = NULL,
...
) {

module <- object
ellipsis::check_dots_empty()
@@ -190,7 +203,10 @@ fit.luz_module_generator <- function(object, data, epochs = 10, callbacks = NULL
ctx$train_data <- data
ctx$valid_data <- valid_data

ctx$epochs <- epochs
if (length(epochs) == 1) epochs <- c(0, epochs)
ctx$min_epochs <- epochs[[1]]
ctx$max_epochs <- epochs[[2]]

callbacks <- append(default_callbacks(), callbacks)
ctx$callbacks <- initialize_callbacks(callbacks, ctx)

@@ -209,7 +225,7 @@ fit.luz_module_generator <- function(object, data, epochs = 10, callbacks = NULL
rlang::with_handlers(
!!! ctx$handlers,
.expr = {
for (epoch in seq_len(ctx$epochs)) {
for (epoch in seq_len(ctx$max_epochs)) {
ctx$epoch <- epoch
ctx$iter <- 0L

38 changes: 22 additions & 16 deletions man/fit.luz_module_generator.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/luz_callback_early_stopping.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/rmd/ctx.Rmd
Original file line number Diff line number Diff line change
@@ -17,7 +17,9 @@ The `ctx` object is used in luz to share information between the training loop a
+------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| `valid_data` | Dataloader passed to the `valid_data` argument in `fit`. Modified to yield data in the selected device. |
+------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| `epochs` | Total number of epochs the model will be trained on. |
| `min_epochs` | Minimum number of epochs the model will be trained for. |
+------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| `max_epochs` | Maximum number of epochs the model will be trained for. |
+------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| `epoch` | Current training epoch. |
+------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

0 comments on commit 6e0bb77

Please sign in to comment.