update comments
Jianguo99 committed Dec 25, 2023
1 parent 3147527 commit d4d2376
Showing 3 changed files with 11 additions and 9 deletions.
4 changes: 2 additions & 2 deletions torchcp/classification/loss/conftr.py
@@ -17,11 +17,11 @@ class ConfTr(nn.Module):
     Paper: https://arxiv.org/abs/2110.09192.
-    :param weights: the weight of each loss function
+    :param weight: the weight of each loss function
     :param predictor: the CP predictors
     :param alpha: the significance level for each training batch
     :param fraction: the fraction of the calibration set in each training batch
-    :param loss_types: the selected (multi-selected) loss functions, which can be "valid", "classification", "probs", "coverage".
+    :param loss_type: the selected (multi-selected) loss functions, which can be "valid", "classification", "probs", "coverage".
     :param target_size: Optional: 0 | 1.
     :param loss_transform: a transform for loss
     :param base_loss_fn: a base loss function. For example, cross entropy in classification.
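To put the renamed arguments in context, here is a minimal, hedged usage sketch. The parameter names (`weight`, `loss_type`, etc.) come from the docstring in this diff; the import paths and the `SplitPredictor` helper are assumptions about the package layout and may differ.

# Hedged usage sketch only: parameter names follow the docstring above;
# the import paths and the SplitPredictor helper are assumed, not confirmed by this diff.
import torch.nn as nn

from torchcp.classification.loss.conftr import ConfTr
from torchcp.classification.predictors import SplitPredictor  # assumed location
from torchcp.classification.scores.thr import THR

# Build a CP predictor from a score function, then wrap it in the ConfTr loss.
predictor = SplitPredictor(THR(score_type="log_softmax"))
criterion = ConfTr(weight=0.01,
                   predictor=predictor,
                   alpha=0.05,
                   fraction=0.5,
                   loss_type="valid",
                   base_loss_fn=nn.CrossEntropyLoss())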
2 changes: 1 addition & 1 deletion torchcp/classification/scores/thr.py
@@ -14,7 +14,7 @@ class THR(BaseScore):
     Threshold conformal predictors (Sadinle et al., 2016).
     paper : https://arxiv.org/abs/1609.00451.
-    param score_type: a transformation on logits. Default: "softmax". Optional: "softmax", "Identity", "log_softmax" or "log".
+    :param score_type: a transformation on logits. Default: "softmax". Optional: "softmax", "Identity", "log_softmax" or "log".
     """

     def __init__(self, score_type="softmax") -> None:
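A small, hedged sketch of the documented `score_type` options. The import path follows the file location shown above; the call convention (applying the score function directly to logits) is an assumption.

# Hedged sketch; the call convention below is assumed, not confirmed by this diff.
import torch

from torchcp.classification.scores.thr import THR

logits = torch.randn(4, 10)            # a small batch of logits
score_fn = THR(score_type="softmax")   # "Identity", "log_softmax" and "log" are also documented
scores = score_fn(logits)              # assumed: returns non-conformity scores per class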
14 changes: 8 additions & 6 deletions torchcp/regression/loss/quantile.py
@@ -3,16 +3,18 @@

 __all__ = ["QuantileLoss"]

+
 class QuantileLoss(nn.Module):
-    """ Pinball loss function
+    """
+    Pinball loss function (Romano et al., 2019)
+    Paper: https://proceedings.neurips.cc/paper_files/paper/2019/file/5103c3584b063c431bd1268e9b5e76fb-Paper.pdf
+    :param quantiles: a list of quantiles, such as $[\frac{alpha}{2}, 1-\frac{alpha}{2}]$.
     """

     def __init__(self, quantiles):
-        """
-        A loss to training a quantile-regression model (Romano et al., 2019).
-        Paper: https://proceedings.neurips.cc/paper_files/paper/2019/file/5103c3584b063c431bd1268e9b5e76fb-Paper.pdf.
-        :param quantiles: a list of quantiles, such as $[\frac{alpha}{2}, 1-\frac{alpha}{2}]$.
-        """
         super().__init__()
         self.quantiles = quantiles
@@ -30,7 +32,7 @@ def forward(self, preds, target):
         losses = preds.new_zeros(len(self.quantiles))

         for i, q in enumerate(self.quantiles):
-            errors = target - preds[:, i:i+1]
+            errors = target - preds[:, i:i + 1]
             losses[i] = torch.sum(torch.max((q - 1) * errors, q * errors).squeeze(1))
         loss = torch.mean(losses)
         return loss
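Based on the `forward` pass shown above (one prediction column per quantile, pinball loss per column), a minimal usage sketch follows; the direct module import is inferred from the file path and the concrete shapes are illustrative only.

# Minimal sketch based on the forward pass shown in this diff.
import torch

from torchcp.regression.loss.quantile import QuantileLoss

alpha = 0.1
criterion = QuantileLoss(quantiles=[alpha / 2, 1 - alpha / 2])

preds = torch.randn(8, 2)         # one column per requested quantile
target = torch.randn(8, 1)        # ground-truth responses
loss = criterion(preds, target)   # mean over quantiles of the summed pinball loss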
