From c8e3bfb560cb1b5c70f35bcfcf648e418849c7ec Mon Sep 17 00:00:00 2001
From: Mikael Mieskolainen
Date: Sun, 28 Jul 2024 19:35:12 +0100
Subject: [PATCH] tune param

---
 configs/zee/models.yml | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/configs/zee/models.yml b/configs/zee/models.yml
index e1c7d3de..5f39b914 100644
--- a/configs/zee/models.yml
+++ b/configs/zee/models.yml
@@ -233,11 +233,11 @@ lzmlp0: &LZMLP
   #lossfunc: 'binary_Lq_entropy'
   #q: 0.8                  # Lq exponent (q < 1 -> high density vals emphasized, q > 1 then low emphasized)
 
-  SWD_beta: 5.0e-2         # Sliced Wasserstein [reweighting regularization]
+  SWD_beta: 1.0e-3         # Sliced Wasserstein [reweighting regularization]
   SWD_p: 1                 # p-norm (1,2,..), 1 perhaps more robust
-  SWD_num_slices: 10000    # Number of MC projections (higher the better)
+  SWD_num_slices: 1000     # Number of MC projections (higher the better)
   SWD_mode: 'SWD'          # 'SWD' (basic)
-  SWD_norm_weights: False  # Normalization enforced
+  SWD_norm_weights: True   # Normalization enforced
 
   lipschitz_beta: 5.0e-5   # lipschitz regularization (use with 'lzmlp')
   #logit_L1_beta: 1.0e-2   # logit norm reg. ~ beta * torch.sum(|logits|)
@@ -322,11 +322,11 @@ fastkan0: &FASTKAN
   #lossfunc: 'binary_Lq_entropy' # binary_cross_entropy, cross_entropy, focal_entropy, logit_norm_cross_entropy
   #q: 0.8                  # Lq exponent (q < 1 -> high density vals emphasized, q > 1 then low emphasized)
 
-  SWD_beta: 5.0e-2         # Sliced Wasserstein [reweighting regularization]
+  SWD_beta: 1.0e-3         # Sliced Wasserstein [reweighting regularization]
   SWD_p: 1                 # p-norm (1,2,..), 1 perhaps more robust
-  SWD_num_slices: 10000    # Number of MC projections (higher the better)
+  SWD_num_slices: 1000     # Number of MC projections (higher the better)
   SWD_mode: 'SWD'          # 'SWD' (basic)
-  SWD_norm_weights: False  # Normalization enforced
+  SWD_norm_weights: True   # Normalization enforced
 
   #lipshitz_beta: 1.0e-4   # Lipshitz regularization (use with 'lzmlp')
   #logit_L1_beta: 1.0e-2   # logit norm reg. ~ beta * torch.sum(|logits|)
@@ -415,12 +415,12 @@ dmlp0: &DMLP
   #lossfunc: 'binary_Lq_entropy'
   #q: 0.8                  # Lq exponent (q < 1 -> high density vals emphasized, q > 1 then low emphasized)
 
-  SWD_beta: 5.0e-2         # Sliced Wasserstein [reweighting regularization]
+  SWD_beta: 1.0e-3         # Sliced Wasserstein [reweighting regularization]
   SWD_p: 1                 # p-norm (1,2,..), 1 perhaps more robust
-  SWD_num_slices: 10000    # Number of MC projections (higher the better)
+  SWD_num_slices: 1000     # Number of MC projections (higher the better)
   SWD_mode: 'SWD'          # 'SWD' (basic)
-  SWD_norm_weights: False  # Normalization enforced
-
+  SWD_norm_weights: True   # Normalization enforced
+
   #logit_L1_beta: 1.0e-2   # logit norm reg. ~ lambda * torch.sum(|logits|)
   logit_L2_beta: 5.0e-3    # logit norm reg. ~ lambda * torch.sum(logits**2)
 
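
Note (not part of the patch itself): the SWD_* keys tuned above configure a sliced Wasserstein distance (SWD) penalty added on top of the classifier loss, with SWD_beta scaling the term, SWD_p the projection-space norm, SWD_num_slices the number of Monte Carlo projection directions, and SWD_norm_weights controlling per-event weight normalization. The sketch below is a minimal, unweighted PyTorch illustration of such a term under those assumptions; it is not the repository's implementation, and the weighted case that SWD_norm_weights refers to would additionally need weighted quantiles.

import torch

def sliced_wasserstein(x, y, num_slices=1000, p=1):
    """Monte Carlo sliced Wasserstein distance between two equal-size samples.

    x, y       : (N, D) feature tensors (e.g. reweighted MC batch vs. data batch)
    num_slices : number of random 1D projections (cf. SWD_num_slices)
    p          : p-norm applied in projection space (cf. SWD_p)
    """
    D = x.shape[1]
    # Random unit vectors defining the 1D projections
    theta = torch.randn(num_slices, D, device=x.device)
    theta = theta / theta.norm(dim=1, keepdim=True)

    # Project, then sort along the sample axis: sorted 1D samples give the
    # optimal 1D transport matching for equal-size, equal-weight samples
    x_proj = (x @ theta.T).sort(dim=0).values     # (N, num_slices)
    y_proj = (y @ theta.T).sort(dim=0).values

    # Average p-th power of the matched differences, then take the p-th root
    return (x_proj - y_proj).abs().pow(p).mean().pow(1.0 / p)

if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(512, 8)            # toy "MC" batch
    y = torch.randn(512, 8) + 0.1      # toy "data" batch, slightly shifted
    swd = sliced_wasserstein(x, y, num_slices=1000, p=1)
    penalty = 1.0e-3 * swd             # scaled by the new SWD_beta value
    print(float(swd), float(penalty))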