From 81e1c580df2bf96515dd005de58ad045da2e0515 Mon Sep 17 00:00:00 2001 From: simon-hirsch Date: Fri, 30 Aug 2024 08:35:51 +0200 Subject: [PATCH 1/3] Remove unnecessary test calls --- tests/test_coordinate_descent.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_coordinate_descent.py b/tests/test_coordinate_descent.py index 11641c8..7d1dc41 100644 --- a/tests/test_coordinate_descent.py +++ b/tests/test_coordinate_descent.py @@ -102,7 +102,3 @@ def test_coordinate_descent_bounds(): assert np.all(rolch_lasso_path_negative <= 0), "Path should contain only betas <= 0" assert np.all(rolch_lasso_path_positive >= 0), "Path should contain only betas >= 0" - - -test_coordinate_descent() -test_coordinate_descent_bounds() From 7b725ae4652fa8e4788faa6370eeeca19df5d341 Mon Sep 17 00:00:00 2001 From: simon-hirsch Date: Fri, 30 Aug 2024 08:34:32 +0200 Subject: [PATCH 2/3] Add test for coefficients against R --- tests/test_python_against_r.py | 46 ++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 tests/test_python_against_r.py diff --git a/tests/test_python_against_r.py b/tests/test_python_against_r.py new file mode 100644 index 0000000..4595b1a --- /dev/null +++ b/tests/test_python_against_r.py @@ -0,0 +1,46 @@ +import numpy as np + +import rolch + +file = "https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv" +mtcars = np.genfromtxt(file, delimiter=",", skip_header=1)[:, 1:] + +y = mtcars[:, 0] +X = mtcars[:, [1, 3]] +X = np.hstack((np.ones((y.shape[0], 1)), X)) + + +def test_normal_distribution(): + # Run the following R code + # library("gamlss") + # data(mtcars) + + # model = gamlss( + # mpg ~ cyl + hp, + # sigma.formula = ~cyl + hp, + # family=NO(), + # data=as.data.frame(mtcars) + # ) + + # coef(model, "mu") + # coef(model, "sigma") + + # To get these coefficients + coef_R_mu = np.array([36.51776626, -2.32470221, -0.01421071]) + coef_R_sg = np.array([1.8782995906, -0.1262290913, -0.0003943062]) + + estimator = rolch.OnlineGamlss( + distribution=rolch.DistributionNormal(), + method="ols", + do_scale=False, + expect_intercept=True, + rss_tol_inner=10, + ) + estimator.fit(y=y, x0=X, x1=X) + + assert np.allclose( + estimator.betas[0], coef_R_mu, atol=0.01 + ), "Location coefficients don't match" + assert np.allclose( + estimator.betas[1], coef_R_sg, atol=0.01 + ), "Scale coefficients don't match" From a8536d7e586919a0817b9f8c663dac23b4b77708 Mon Sep 17 00:00:00 2001 From: simon-hirsch Date: Fri, 30 Aug 2024 09:23:51 +0200 Subject: [PATCH 3/3] Update gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 131d225..80455c7 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ presentation/* .cache .devcontainer/* .coverage -build.txt \ No newline at end of file +build.txt +gist.githubusercontent* \ No newline at end of file