From 0877779fd9c0afb817f30ce2e0f31164fb69657f Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Wed, 15 Jan 2025 19:05:06 +0100 Subject: [PATCH 1/4] ope): Brief description fix(ignite): making state saveable, param_groups modifieable It was impossible to save the state because of: * https://github.com/mlverse/torch/issues/1233 * undefined tensors were part of the state such as max_exp_avg_sq for adam with amsgrad = FALSE. We now keep them as 0-sized tensors as undefined tensors are not serializeable. (The reason we keep them at all is that it simplifies the saving and loading of state dicts easier) This PR also improves the tests by removing an unnecessary call to `torch_manual_seed()` that made the tests deterministic --- R/ignite.R | 15 +++++-- src/lantern/src/Ignite.cpp | 29 ++++++++------ tests/testthat/helper-ignite.R | 4 +- tests/testthat/test-ignite.R | 71 ++++++++++++++++++++++++++++++++-- 4 files changed, 97 insertions(+), 22 deletions(-) diff --git a/R/ignite.R b/R/ignite.R index 2eb7eb886d..6755b77d6b 100644 --- a/R/ignite.R +++ b/R/ignite.R @@ -117,7 +117,8 @@ OptimizerIgnite <- R6::R6Class( #' The parameter groups of the optimizer. param_groups = function(rhs) { if (!missing(rhs)) { - prev_param_groups <- self$state_dict()$param_groups + prev_param_groups <- self$param_groups + all_params = unlist(lapply(prev_param_groups, function(x) x$params)) if (!is.list(rhs) && length(rhs) == length(prev_param_groups)) { value_error("Parameter groups must be a list of the same length as the number of parameter groups.") } @@ -128,8 +129,16 @@ OptimizerIgnite <- R6::R6Class( value_error("Parameter groups must have names {paste0(names(prev_param_group), collapse = ', ')} but got {paste0(names(new_param_group), collapse = ', ')}.") } - if (!identical(prev_param_group$params, new_param_group$params)) { - value_error("Cannot change the indices of the parameter group, use `$add_param_group()` to add a new parameter group.") + param_cmp_value = if (is.integer(new_param_group$params)) { + all_params[new_param_group$params] + } else { + new_param_group$params + } + + if (!identical(prev_param_group$params, param_cmp_value)) { + print(prev_param_group$params) + print(new_param_group$params) + value_error("Cannot change the parameter groups, use `$add_param_group()` to add a new parameter group.") } private$.set_param_group_options(self$ptr, rhs) diff --git a/src/lantern/src/Ignite.cpp b/src/lantern/src/Ignite.cpp index 0726a4f44c..b1c0384aaa 100644 --- a/src/lantern/src/Ignite.cpp +++ b/src/lantern/src/Ignite.cpp @@ -1,5 +1,6 @@ #include #include +#include #define LANTERN_BUILD #include #include "lantern/lantern.h" @@ -113,7 +114,7 @@ void* _ignite_adagrad_get_states(void* optim) { auto base_state = state_it->second.get(); auto adagrad_state = static_cast(base_state); tensors.push_back(adagrad_state->sum().clone()); - tensors.push_back(torch::scalar_tensor(adagrad_state->step(), torch::kLong)); + tensors.push_back(torch::tensor({adagrad_state->step()}, torch::kLong)); } } } @@ -201,9 +202,9 @@ void* _ignite_adam_get_states(void* optim) { if (adam_state->max_exp_avg_sq().defined()) { tensors.push_back(adam_state->max_exp_avg_sq().clone()); } else { - tensors.push_back(torch::Tensor()); + tensors.push_back(torch::empty(0, torch::kFloat32)); } - tensors.push_back(torch::scalar_tensor(adam_state->step(), torch::kLong)); + tensors.push_back(torch::tensor({adam_state->step()}, torch::kLong)); } } } @@ -226,7 +227,7 @@ void _ignite_adam_set_states(void* optim, void* params,void* states_) { auto* current_state = static_cast(state_it->second.get()); current_state->exp_avg(states[i]); current_state->exp_avg_sq(states[i + 1]); - if (states[i + 2].defined()) { + if (states[i + 2].numel() != 0) { current_state->max_exp_avg_sq(states[i + 2]); } auto step = states[i + 3]; @@ -324,9 +325,9 @@ void* _ignite_adamw_get_states(void* optim) { if (adamw_state->max_exp_avg_sq().defined()) { tensors.push_back(adamw_state->max_exp_avg_sq().clone()); } else { - tensors.push_back(torch::Tensor()); + tensors.push_back(torch::empty(0, torch::kFloat32)); } - tensors.push_back(torch::scalar_tensor(adamw_state->step(), torch::kLong)); + tensors.push_back(torch::tensor({adamw_state->step()}, torch::kLong)); } } } @@ -422,15 +423,15 @@ void* _ignite_rmsprop_get_states(void* optim) { if (rmsprop_state->grad_avg().defined()) { tensors.push_back(rmsprop_state->grad_avg().clone()); } else { - tensors.push_back(torch::Tensor()); + tensors.push_back(torch::empty(0, torch::kFloat32)); } tensors.push_back(rmsprop_state->square_avg().clone()); if (rmsprop_state->momentum_buffer().defined()) { tensors.push_back(rmsprop_state->momentum_buffer().clone()); } else { - tensors.push_back(torch::Tensor()); + tensors.push_back(torch::empty(0, torch::kFloat32)); } - tensors.push_back(torch::scalar_tensor(rmsprop_state->step(), torch::kLong)); + tensors.push_back(torch::tensor({rmsprop_state->step()}, torch::kLong)); } } } @@ -451,9 +452,13 @@ void _ignite_rmsprop_set_states(void* optim, void* params, void* states_) { state_it = opt->state().find(param.unsafeGetTensorImpl()); } auto* current_state = static_cast(state_it->second.get()); - current_state->grad_avg(states[i]); + if (states[i].numel() != 0) { + current_state->grad_avg(states[i]); + } current_state->square_avg(states[i + 1]); - current_state->momentum_buffer(states[i + 2]); + if (states[i + 2].numel() != 0) { + current_state->momentum_buffer(states[i + 2]); + } auto step = states[i + 3]; current_state->step(step.item()); i += 4; @@ -519,7 +524,7 @@ void* _ignite_sgd_get_states(void* optim) { if (sgd_state->momentum_buffer().defined()) { tensors.push_back(sgd_state->momentum_buffer().clone()); } else { - tensors.push_back(torch::Tensor()); + tensors.push_back(torch::empty(0, torch::kFloat32)); } } } diff --git a/tests/testthat/helper-ignite.R b/tests/testthat/helper-ignite.R index 6bea5678af..069e896b2c 100644 --- a/tests/testthat/helper-ignite.R +++ b/tests/testthat/helper-ignite.R @@ -75,14 +75,12 @@ expect_state_dict_works <- function(optimizer_fn, ...) { } replicate(2, s()) if (load) { - o$load_state_dict(o$state_dict()) + o$load_state_dict(torch_load(torch_serialize(o$state_dict()))) } replicate(2, s()) return(n$parameters) } - torch_manual_seed(123) w1 <- f(load = TRUE) - torch_manual_seed(123) w2 <- f(load = FALSE) expect_equal(w1, w2) } diff --git a/tests/testthat/test-ignite.R b/tests/testthat/test-ignite.R index 92f22eaf0b..85b46321f2 100644 --- a/tests/testthat/test-ignite.R +++ b/tests/testthat/test-ignite.R @@ -30,10 +30,10 @@ test_that("un-optimized parameters and state dict", { states = sd$state expect_equal(names(states), "1") # all parameters are included in the state dict even when they don't have a state. - expect_false(cpp_tensor_is_undefined(states[[1]]$exp_avg)) - expect_false(cpp_tensor_is_undefined(states[[1]]$exp_avg_sq)) - expect_true(cpp_tensor_is_undefined(states[[1]]$max_exp_avg_sq)) - expect_false(cpp_tensor_is_undefined(states[[1]]$step)) + expect_false(is.null(states[[1]]$exp_avg)) + expect_false(is.null(states[[1]]$exp_avg_sq)) + expect_false(is.null(states[[1]]$max_exp_avg_sq)) + expect_false(is.null(states[[1]]$step)) opt$load_state_dict(sd) x1 = unlist(states) x2 = unlist(opt$state_dict()$state) @@ -58,6 +58,12 @@ test_that("adam", { expect_ignite_can_change_param_groups(optim_ignite_adam) expect_ignite_can_add_param_group(optim_ignite_adam) do.call(expect_state_dict_works, c(list(optim_ignite_adam), defaults)) + # can save adam even when one of the tensors in the state is undefined in C++ + defaults$amsgrad <- FALSE + o <- do.call(make_ignite_adam, defaults) + prev <- o$state_dict() + o$load_state_dict(torch_load(torch_serialize(o$state_dict()))) + expect_equal(prev, o$state_dict()) }) test_that("adamw", { @@ -73,6 +79,22 @@ test_that("adamw", { expect_ignite_can_change_param_groups(optim_ignite_adamw) expect_ignite_can_add_param_group(optim_ignite_adamw) do.call(expect_state_dict_works, c(list(optim_ignite_adamw), defaults)) + + # can save adamw even when one of the tensors in the state is undefined in C++ + defaults$amsgrad <- FALSE + o <- do.call(make_ignite_adamw, defaults) + prev <- o$state_dict() + o$load_state_dict(torch_load(torch_serialize(o$state_dict()))) + expect_equal(prev, o$state_dict()) + o$param_groups[[1L]]$amsgrad <- TRUE + + # now we check whether an uninitialized state parameter can be created later + step <- function() { + ((o$param_groups[[1]][[1]][[1]] * o$param_groups[[1]][[1]][[2]] * torch_tensor(1) - torch_tensor(2))^2)$backward() + } + step() + step() + (o$param_groups[[1L]]$amsgrad) }) test_that("sgd", { @@ -87,6 +109,20 @@ test_that("sgd", { expect_ignite_can_change_param_groups(optim_ignite_sgd, lr = 0.1) expect_ignite_can_add_param_group(optim_ignite_sgd) do.call(expect_state_dict_works, c(list(optim_ignite_sgd), defaults)) + o$load_state_dict(torch_load(torch_serialize(o$state_dict()))) + + # saving of state dict + defaults$momentum = FALSE + o <- do.call(make_ignite_sgd, defaults) + prev <- o$state_dict() + o$load_state_dict(torch_load(torch_serialize(o$state_dict()))) + expect_equal(prev, o$state_dict()) + + defaults$momentum = TRUE + o <- do.call(make_ignite_sgd, defaults) + prev <- o$state_dict() + o$load_state_dict(torch_load(torch_serialize(o$state_dict()))) + expect_equal(prev, o$state_dict()) }) test_that("rmsprop", { @@ -102,6 +138,11 @@ test_that("rmsprop", { expect_ignite_can_change_param_groups(optim_ignite_rmsprop) expect_ignite_can_add_param_group(optim_ignite_rmsprop) do.call(expect_state_dict_works, c(list(optim_ignite_rmsprop), defaults)) + + o <- do.call(make_ignite_rmsprop, defaults) + prev <- o$state_dict() + o$load_state_dict(torch_load(torch_serialize(o$state_dict()))) + expect_equal(prev, o$state_dict()) }) test_that("adagrad", { @@ -117,6 +158,11 @@ test_that("adagrad", { expect_ignite_can_change_param_groups(optim_ignite_adagrad) expect_ignite_can_add_param_group(optim_ignite_adagrad) do.call(expect_state_dict_works, c(list(optim_ignite_adagrad), defaults)) + + o <- do.call(make_ignite_adagrad, defaults) + prev <- o$state_dict() + o$load_state_dict(torch_load(torch_serialize(o$state_dict()))) + expect_equal(prev, o$state_dict()) }) test_that("base class: can initialize optimizer with different options per param group", { @@ -160,6 +206,23 @@ test_that("base class: params must have length > 1", { expect_error(optim_ignite_adamw(list()), "must have length") }) +test_that("can change values of param groups in optimizer", { + o <- make_ignite_adamw(amsgrad = TRUE) + o$param_groups[[1]]$amsgrad <- FALSE + expect_false(o$param_groups[[1]]$amsgrad) +}) + +test_that("base class: can change values of param_groups", { + o = optim_ignite_adamw(list(torch_tensor(1, requires_grad = TRUE)), lr = 0.1) + o$param_groups[[1]]$lr = 1 + expect_equal(o$param_groups[[1]]$lr, 1) + o$param_groups[[1]]$amsgrad = FALSE + expect_true(!o$param_groups[[1]]$amsgrad) + o$param_groups[[1]]$amsgrad = TRUE + expect_false(!o$param_groups[[1]]$amsgrad) +}) + + test_that("base class: error handling when loading state dict", { o = make_ignite_adamw() expect_error(o$load_state_dict(list()), "must be a list with elements") From 2590b2306f5ded59523370c4765f0f9fbd2b98db Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Wed, 15 Jan 2025 19:19:58 +0100 Subject: [PATCH 2/4] cleanup previous commit --- tests/testthat/test-ignite.R | 38 +++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/tests/testthat/test-ignite.R b/tests/testthat/test-ignite.R index 85b46321f2..ae04b6d825 100644 --- a/tests/testthat/test-ignite.R +++ b/tests/testthat/test-ignite.R @@ -86,15 +86,6 @@ test_that("adamw", { prev <- o$state_dict() o$load_state_dict(torch_load(torch_serialize(o$state_dict()))) expect_equal(prev, o$state_dict()) - o$param_groups[[1L]]$amsgrad <- TRUE - - # now we check whether an uninitialized state parameter can be created later - step <- function() { - ((o$param_groups[[1]][[1]][[1]] * o$param_groups[[1]][[1]][[2]] * torch_tensor(1) - torch_tensor(2))^2)$backward() - } - step() - step() - (o$param_groups[[1L]]$amsgrad) }) test_that("sgd", { @@ -206,12 +197,6 @@ test_that("base class: params must have length > 1", { expect_error(optim_ignite_adamw(list()), "must have length") }) -test_that("can change values of param groups in optimizer", { - o <- make_ignite_adamw(amsgrad = TRUE) - o$param_groups[[1]]$amsgrad <- FALSE - expect_false(o$param_groups[[1]]$amsgrad) -}) - test_that("base class: can change values of param_groups", { o = optim_ignite_adamw(list(torch_tensor(1, requires_grad = TRUE)), lr = 0.1) o$param_groups[[1]]$lr = 1 @@ -237,7 +222,28 @@ test_that("base class: error handling when loading state dict", { expect_error(o$load_state_dict(sd3), "but got params, weight_decay") }) -test_that("deep cloning not possible", { +test_that("base class: deep cloning not possible", { o = make_ignite_adamw(steps = 0) expect_error(o$clone(deep = TRUE), "OptimizerIgnite cannot be deep cloned") }) + +test_that("base class: changing the learning rate has an effect", { + n1 = nn_linear(1, 1) + n2 = n1$clone(deep = TRUE) + o1 = optim_sgd(n1$parameters, lr = 0.1) + o2 = optim_sgd(n2$parameters, lr = 0.1) + + s = function(n, o) { + o$zero_grad() + ((n(torch_tensor(1)) - torch_tensor(1))^2)$backward() + o$step() + } + + s(n1, o1) + s(n2, o2) + expect_true(torch_equal(n1$parameters[[1]], n2$parameters[[1]]) && torch_equal(n1$parameters[[2]], n2$parameters[[2]])) + o1$param_groups[[1]]$lr = 0.2 + s(n1, o1) + s(n2, o2) + expect_false(torch_equal(n1$parameters[[1]], n2$parameters[[1]]) && torch_equal(n1$parameters[[2]], n2$parameters[[2]])) +}) From b54cf6e881ebfa5f22b4f0149c4ca3ab2f0ec359 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Thu, 16 Jan 2025 07:46:55 +0100 Subject: [PATCH 3/4] trigger ci --- R/ignite.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/ignite.R b/R/ignite.R index 6755b77d6b..fe14243d9e 100644 --- a/R/ignite.R +++ b/R/ignite.R @@ -376,7 +376,6 @@ is_permutation <- function(vec1, vec2) { if (length(vec1) != length(vec2)) { return(FALSE) } - # Check if sorted elements are the same identical(sort(vec1), sort(vec2)) } From f86edb2973b9fb685b8860daa600c2498d6b1843 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Thu, 16 Jan 2025 09:37:37 +0100 Subject: [PATCH 4/4] tests: fix ignite sgd test --- tests/testthat/test-ignite.R | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/testthat/test-ignite.R b/tests/testthat/test-ignite.R index ae04b6d825..2d7eccb7d9 100644 --- a/tests/testthat/test-ignite.R +++ b/tests/testthat/test-ignite.R @@ -103,13 +103,6 @@ test_that("sgd", { o$load_state_dict(torch_load(torch_serialize(o$state_dict()))) # saving of state dict - defaults$momentum = FALSE - o <- do.call(make_ignite_sgd, defaults) - prev <- o$state_dict() - o$load_state_dict(torch_load(torch_serialize(o$state_dict()))) - expect_equal(prev, o$state_dict()) - - defaults$momentum = TRUE o <- do.call(make_ignite_sgd, defaults) prev <- o$state_dict() o$load_state_dict(torch_load(torch_serialize(o$state_dict())))