-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlovaz1.py
88 lines (74 loc) · 2.93 KB
/
lovaz1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
Lovasz-Softmax and Jaccard hinge loss in Tensorflow
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""
from __future__ import print_function, division
import tensorflow as tf
import numpy as np
def lovasz_grad(gt_sorted):
    """
    Gradient of the Lovasz extension with respect to sorted errors
    (Alg. 1 in the paper)
      gt_sorted: 1-D ground truth values, ordered like the sorted errors
    """
    total_positive = tf.reduce_sum(gt_sorted)
    # Running intersection / union as one more sorted element is included
    remaining_positive = total_positive - tf.cumsum(gt_sorted)
    covered = total_positive + tf.cumsum(1. - gt_sorted)
    jaccard_loss = 1. - remaining_positive / covered
    # First entry is kept as is; the others become successive differences
    head = jaccard_loss[0:1]
    increments = jaccard_loss[1:] - jaccard_loss[:-1]
    return tf.concat((head, increments), 0)
# --------------------------- BINARY LOSSES ---------------------------
def lovasz_hinge(logits, labels, per_image=False, ignore=None):
    """
    Binary Lovasz hinge loss
      logits: [B, H, W] Variable, logits at each pixel (between -infinity and +infinity)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      per_image: compute the loss per image instead of per batch
      ignore: void class id

    With per_image set, the loss is evaluated separately for every element
    along the first axis and the results are averaged; otherwise the whole
    batch is flattened and scored at once.
    """
    if per_image:
        def treat_image(log_lab):
            log, lab = log_lab
            log, lab = tf.expand_dims(log, 0), tf.expand_dims(lab, 0)
            log, lab = flatten_binary_scores(log, lab, ignore)
            return lovasz_hinge_flat(log, lab)

        # lovasz_hinge_flat yields a loss in the dtype of the logits, so the
        # map_fn output dtype has to follow it instead of assuming float32
        # (a hard-coded float32 breaks float16/float64 logits). Inputs that
        # expose no dtype keep the previous float32 behaviour; base_dtype
        # strips a possible reference dtype.
        loss_dtype = tf.as_dtype(getattr(logits, "dtype", tf.float32)).base_dtype
        losses = tf.map_fn(treat_image, (logits, labels), dtype=loss_dtype)
        loss = tf.reduce_mean(losses)
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss on flattened predictions
      logits: [P] Variable, logits at each prediction (between -infinity and +infinity)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    def loss_on_predictions():
        targets = tf.cast(labels, logits.dtype)
        # Turn the 0/1 labels into -1/+1 signs
        polarity = 2. * targets - 1.
        margins = 1. - logits * tf.stop_gradient(polarity)
        # top_k over the full length sorts the errors in decreasing order
        count = tf.shape(margins)[0]
        margins_desc, order = tf.nn.top_k(margins, k=count, name="descending_sort")
        targets_desc = tf.gather(targets, order)
        weights = lovasz_grad(targets_desc)
        hinge = tf.nn.relu(margins_desc)
        # The weights are treated as constants during backpropagation
        return tf.tensordot(hinge, tf.stop_gradient(weights), 1, name="loss_non_void")

    def loss_on_nothing():
        # Zero loss that is still connected to the logits
        return tf.reduce_sum(logits) * 0.

    # deal with the void prediction case (only void pixels)
    no_predictions = tf.equal(tf.shape(logits)[0], 0)
    return tf.cond(no_predictions, loss_on_nothing, loss_on_predictions, name="loss")
def flatten_binary_scores(scores, labels, ignore=None):
    """
    Reshape the batch predictions and labels to 1-D (binary case)
    When 'ignore' is given, entries whose label equals it are dropped
    from both outputs
    """
    flat_scores = tf.reshape(scores, (-1,))
    flat_labels = tf.reshape(labels, (-1,))
    if ignore is not None:
        keep = tf.not_equal(flat_labels, ignore)
        flat_scores = tf.boolean_mask(flat_scores, keep, name='valid_scores')
        flat_labels = tf.boolean_mask(flat_labels, keep, name='valid_labels')
    return flat_scores, flat_labels