Branch
main branch (mmpretrain version)
Describe the bug
I am training a binary classification model to predict whether an image belongs to class 1. For this, I use the following loss function configuration: type='CrossEntropyLoss', use_sigmoid=True.
Below is the full model configuration:
The dataset structure is as follows:
When starting the training process, I encounter an error related to mismatched dimensions between predictions and labels:
Here are the dimensions of the predictions and labels printed during debugging:
It appears that pred has the shape (32, 1), while label has the shape (32,). This causes the assertion assert pred.dim() == label.dim() in the binary_cross_entropy function to fail.
When I disable the sigmoid (use_sigmoid=False) and set num_classes=2, the training works without errors. However, the issue persists when using use_sigmoid=True and num_classes=1.
Is this a bug in the implementation of CrossEntropyLoss with use_sigmoid=True, or am I misconfiguring something? Any clarification or suggestions would be greatly appreciated.
Environment
Other information
No response
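For reference, the shape mismatch described in this report can be reproduced outside mmpretrain. The following is a minimal sketch in plain PyTorch, not mmpretrain's own code: the tensors are random stand-ins with the reported shapes, and the reshaping step is only a hypothetical workaround to illustrate what the dimension check expects. Whether this is the intended usage of use_sigmoid=True in mmpretrain is not established here.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Random stand-ins with the shapes from the report:
# logits from a num_classes=1 head, and integer class labels.
pred = torch.randn(32, 1)
label = torch.randint(0, 2, (32,))

# Mirrors the failing check: pred has 2 dimensions, label has 1.
dims_match = pred.dim() == label.dim()
print(pred.shape, label.shape, dims_match)

# Hypothetical workaround (an assumption, not a confirmed mmpretrain fix):
# cast the label to float and add a trailing axis so it matches pred.
label_2d = label.float().unsqueeze(1)

# Sigmoid-based binary cross-entropy on raw logits, standard PyTorch API.
loss = F.binary_cross_entropy_with_logits(pred, label_2d)
print(label_2d.shape, loss.item())
```

With labels shaped (32, 1), both tensors have the same number of dimensions, which is the condition the assertion checks.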