Skip to content

Commit

Permalink
Update xgboostlss.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Jan 25, 2025
1 parent 0fbc5c4 commit c7d844d
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions skpro/regression/xgboostlss.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def __init__(
else:
self._n_cpu = n_cpu

def _get_distr(self, distr):
"""Get distribution object from string."""
def _get_xgblss_distr(self, distr):
"""Get xgboostlss distribution object from string."""
import importlib

SKPRO_TO_XGBLSS = {
Expand All @@ -120,6 +120,16 @@ def _get_distr(self, distr):
module = importlib.import_module(module_str)
return getattr(module, object_str)

def _get_skpro_distr(self, distr):
"""Get skpro distribution object from string."""
import importlib

module_str = "skpro.distributions." + distr
object_str = distr

module = importlib.import_module(module_str)
return getattr(module, object_str)

def _fit(self, X, y):
"""Fit regressor to training data.
Expand All @@ -145,10 +155,10 @@ def _fit(self, X, y):

dtrain = xgb.DMatrix(X, label=y, nthread=n_cpu, silent=True)

distr = self._get_distr(self.dist)
xgblss_distr = self._get_xgblss_distr(self.dist)

xgblss = XGBoostLSS(
distr(
xgblss_distr(
stabilization="None",
response_fn="exp",
loss_fn="nll",
Expand Down Expand Up @@ -231,9 +241,9 @@ def _predict_proba(self, X):

y_pred_xgblss = self.xgblss_.predict(dtest, pred_type="parameters")

from skpro.distributions.normal import Normal
skpro_distr = self._get_skpro_distr(self.dist)

y_pred = Normal(
y_pred = skpro_distr(
mu=y_pred_xgblss.iloc[:, [0]].values, # mean is first column
sigma=y_pred_xgblss.iloc[:, [1]].values, # scale is second column
index=index,
Expand Down

0 comments on commit c7d844d

Please sign in to comment.