From b2a2f4c09e555eca2a71198240970f027c041ba8 Mon Sep 17 00:00:00 2001 From: Philipp Haarmeyer Date: Fri, 26 Jul 2024 13:31:02 +0200 Subject: [PATCH] Register tunable {tabnet} parameters correctly --- R/parsnip.R | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/R/parsnip.R b/R/parsnip.R index e7e1ed9..de1d9a3 100644 --- a/R/parsnip.R +++ b/R/parsnip.R @@ -94,7 +94,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "decision_width", original = "decision_width", - func = list(pkg = "dials", fun = "decision_width"), + func = list(pkg = "tabnet", fun = "decision_width"), has_submodel = FALSE ) @@ -103,7 +103,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "attention_width", original = "attention_width", - func = list(pkg = "dials", fun = "attention_width"), + func = list(pkg = "tabnet", fun = "attention_width"), has_submodel = FALSE ) @@ -112,7 +112,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "num_steps", original = "num_steps", - func = list(pkg = "dials", fun = "num_steps"), + func = list(pkg = "tabnet", fun = "num_steps"), has_submodel = FALSE ) @@ -121,7 +121,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "mask_type", original = "mask_type", - func = list(pkg = "dials", fun = "mask_type"), + func = list(pkg = "tabnet", fun = "mask_type"), has_submodel = FALSE ) @@ -157,7 +157,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "num_independent", original = "num_independent", - func = list(pkg = "dials", fun = "num_independent"), + func = list(pkg = "tabnet", fun = "num_independent"), has_submodel = FALSE ) @@ -166,7 +166,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "num_shared", original = "num_shared", - func = list(pkg = "dials", fun = "num_shared"), + func = list(pkg = "tabnet", fun = "num_shared"), has_submodel = FALSE ) @@ -202,7 +202,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "feature_reusage", original = "feature_reusage", - func = list(pkg = "dials", fun = "feature_reusage"), + func = list(pkg = "tabnet", fun = "feature_reusage"), has_submodel = FALSE ) @@ -211,7 +211,7 @@ add_parsnip_tabnet <- function() { eng = "torch", parsnip = "momentum", original = "momentum", - func = list(pkg = "dials", fun = "momentum"), + func = list(pkg = "tabnet", fun = "momentum"), has_submodel = FALSE )