-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathloss.py
36 lines (27 loc) · 971 Bytes
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import tensorflow as tf
import numpy as np
# `tf.contrib` (and with it `tf.contrib.slim`) was removed in TensorFlow 2.x,
# while the functions below only use APIs available in TF2 (`tf.nn`,
# `tf.compat.v1`). Fall back to None so this module still imports under TF2;
# nothing in this file uses `slim`.
try:
    slim = tf.contrib.slim
except AttributeError:
    slim = None
def compute_heatmaps_loss(gt_heatmaps, pred_heatmaps, cfg=''):
    """Return the summed L2 loss of every prediction against the ground truth.

    Unlike ``add_heatmaps_loss``, nothing is registered in any collection and
    no summaries are created; the total is only returned.

    Args:
        gt_heatmaps: Ground-truth heatmaps, compared against each prediction.
        pred_heatmaps: Sequence of predicted heatmaps, presumably one per
            hourglass unit. Each entry must be subtractable from
            ``gt_heatmaps`` (same or broadcast-compatible shape).
        cfg: Unused; kept for interface compatibility.

    Returns:
        The sum over ``pred_heatmaps`` of ``tf.nn.l2_loss(gt_heatmaps - pred)``
        (each term is ``sum((gt - pred) ** 2) / 2``). Returns the Python
        float ``0.0`` when ``pred_heatmaps`` is empty.
    """
    # Lazily evaluated: tf is only touched if there is at least one prediction.
    per_unit_losses = (tf.nn.l2_loss(gt_heatmaps - pred) for pred in pred_heatmaps)
    # Start from 0.0 so the first addition is 0.0 + tensor, as in a manual loop.
    return sum(per_unit_losses, 0.0)
def add_heatmaps_loss(gt_heatmaps, pred_heatmaps, add_summaries, cfg=''):
    """Register one L2 heatmap loss per prediction and return their sum.

    For each prediction in ``pred_heatmaps`` (one per hourglass unit), the
    loss ``tf.nn.l2_loss(gt_heatmaps - pred)``, i.e.
    ``sum((gt - pred) ** 2) / 2``, is computed. Each loss is:

    - registered via ``tf.compat.v1.losses.add_loss``;
    - accumulated into the returned total;
    - optionally logged as a scalar summary.

    Args:
        gt_heatmaps: The ground-truth heatmaps, compared against every entry
            of ``pred_heatmaps``. Presumably PART_NUM x INPUT_SIZE x
            INPUT_SIZE maps; TODO confirm the exact layout with the caller.
        pred_heatmaps: A sequence of predicted heatmaps, one per hourglass
            unit. Each entry must be subtractable from ``gt_heatmaps``
            (same or broadcast-compatible shape).
        add_summaries: If truthy, also create a ``heatmap_loss_<i>`` scalar
            summary for the i-th unit's loss.
        cfg: Unused; kept for interface compatibility.

    Returns:
        A ``(total_loss, summaries)`` tuple:

        - ``total_loss`` is the sum of the per-unit losses, or the Python
          float ``0.0`` if ``pred_heatmaps`` is empty.
        - ``summaries`` is the list of summary ops created, empty when
          ``add_summaries`` is falsy.
    """
    total_loss = 0.0
    summaries = []
    for i, pred in enumerate(pred_heatmaps):  # For each hourglass unit...
        unit_loss = tf.nn.l2_loss(gt_heatmaps - pred)
        # Add to the (default) LOSSES collection so collection-based helpers
        # such as tf.compat.v1.losses.get_total_loss() include this term.
        tf.compat.v1.losses.add_loss(unit_loss)
        total_loss += unit_loss
        if add_summaries:
            summaries.append(
                tf.compat.v1.summary.scalar('heatmap_loss_%d' % i, unit_loss))
    return total_loss, summaries