-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
50 lines (40 loc) · 1.69 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from tkinter import W
import albumentations as A
import torch
from torch import nn
from torch.utils.data import DataLoader
from utils import intersection_over_union as iou
class YoloV3Loss(nn.Module):
    """YOLOv3-style loss for the predictions of a single detection scale.

    The returned loss is a weighted sum of four terms:

    * no-object loss: BCE-with-logits on the objectness logit of cells whose
      target objectness is exactly 0;
    * object loss: MSE between the sigmoid of the objectness logit and the
      (detached) IoU of the decoded predicted box with the target box;
    * coordinate loss: MSE between (sigmoid(x), sigmoid(y), raw w, raw h)
      predictions and (x, y, log(w / anchor_w), log(h / anchor_h)) targets;
    * class loss: cross-entropy over the class logits.

    Args:
        no_obj_weight: Multiplier for the no-object term (default 0.5).
        coord_weight: Multiplier for the coordinate term (default 5).
    """

    def __init__(self, no_obj_weight: float = 0.5, coord_weight: float = 5) -> None:
        super().__init__()
        self.sigmoid = nn.Sigmoid()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.cross_entropy = nn.CrossEntropyLoss()
        self.no_obj_weight = no_obj_weight
        self.coord_weight = coord_weight

    def forward(
        self, predictions: torch.Tensor, targets: torch.Tensor, anchors: torch.Tensor
    ) -> torch.Tensor:
        """Compute the scalar loss for one scale.

        Neither ``predictions`` nor ``targets`` is modified.

        Args:
            predictions: Raw network output. The last dimension holds the
                objectness logit at index 0, box (x, y, w, h) at 1:5 and class
                logits at 5:. Presumably shaped
                (batch, num_anchors, S, S, 5 + num_classes); TODO confirm.
            targets: Same leading dims as ``predictions``. The last dimension
                holds objectness at index 0 (1 = object, 0 = no object; any
                other value falls in neither mask and is ignored), box
                (x, y, w, h) at 1:5 and the class index at 5.
            anchors: 2-D tensor, presumably (num_anchors, 2) anchor
                widths/heights for this scale; TODO confirm. It is reshaped so
                that its first axis lines up with dim 1 of ``predictions``.

        Returns:
            A scalar loss tensor.
        """
        obj = targets[..., 0] == 1
        no_obj = targets[..., 0] == 0

        no_obj_loss = self.bce(
            predictions[..., 0:1][no_obj], targets[..., 0:1][no_obj]
        )

        # With no object cells, the remaining terms are means over zero
        # elements and would evaluate to NaN, poisoning the total loss.
        if not obj.any():
            return self.no_obj_weight * no_obj_loss

        # (A, 2) -> (1, A, 1, 1, 2) so anchors broadcast over batch and grid.
        anchors = anchors[None, :, None, None, :]

        # Decode predicted boxes as in the YOLOv3 paper:
        # sigmoid offsets for x, y and anchor * exp(t) for w, h.
        box_xy = self.sigmoid(predictions[..., 1:3][obj])
        box_wh = (torch.exp(predictions[..., 3:5]) * anchors)[obj]
        iou_scores = iou(
            torch.cat([box_xy, box_wh], dim=1), targets[..., 1:5][obj]
        ).detach()  # IoU is a regression target only; no gradient through it.
        obj_loss = self.mse(self.sigmoid(predictions[..., 0:1][obj]), iou_scores)

        # Build regression inputs/targets out of place so the caller's
        # tensors are left untouched. w, h are compared in log space relative
        # to the anchors; 1e-16 guards against log(0).
        pred_boxes = torch.cat(
            [self.sigmoid(predictions[..., 1:3]), predictions[..., 3:5]], dim=-1
        )
        target_boxes = torch.cat(
            [targets[..., 1:3], torch.log(1e-16 + targets[..., 3:5] / anchors)],
            dim=-1,
        )
        coord_loss = self.mse(pred_boxes[obj], target_boxes[obj])

        class_loss = self.cross_entropy(
            predictions[..., 5:][obj], targets[..., 5][obj].long()
        )

        return (
            self.no_obj_weight * no_obj_loss
            + self.coord_weight * coord_loss
            + obj_loss
            + class_loss
        )