From ec2394f5048b1cef8784ec23f82a37c082058f67 Mon Sep 17 00:00:00 2001 From: Haziq Razali Date: Mon, 23 Sep 2019 22:45:46 +0800 Subject: [PATCH] Add files via upload --- openpifpaf/network/__init__.py | 3 + openpifpaf/network/basenetworks.py | 277 ++++++++++++++++++++++ openpifpaf/network/heads.py | 199 ++++++++++++++++ openpifpaf/network/losses.py | 361 +++++++++++++++++++++++++++++ openpifpaf/network/nets.py | 335 ++++++++++++++++++++++++++ openpifpaf/network/trainer.py | 341 +++++++++++++++++++++++++++ 6 files changed, 1516 insertions(+) create mode 100644 openpifpaf/network/__init__.py create mode 100644 openpifpaf/network/basenetworks.py create mode 100644 openpifpaf/network/heads.py create mode 100644 openpifpaf/network/losses.py create mode 100644 openpifpaf/network/nets.py create mode 100644 openpifpaf/network/trainer.py diff --git a/openpifpaf/network/__init__.py b/openpifpaf/network/__init__.py new file mode 100644 index 0000000..1f2782b --- /dev/null +++ b/openpifpaf/network/__init__.py @@ -0,0 +1,3 @@ +from .nets import factory, factory_from_args +from .trainer import Trainer +from . import losses diff --git a/openpifpaf/network/basenetworks.py b/openpifpaf/network/basenetworks.py new file mode 100644 index 0000000..d86ecf1 --- /dev/null +++ b/openpifpaf/network/basenetworks.py @@ -0,0 +1,277 @@ +import copy +import torch + + +class BaseNetwork(torch.nn.Module): + """Common base network.""" + + def __init__(self, net, shortname, input_output_scale, out_features): + super(BaseNetwork, self).__init__() + + self.net = net + self.shortname = shortname + self.input_output_scale = input_output_scale + self.out_features = out_features + self.topology = 'linear' + + # print(list(net.children())) + print('input output scale', self.input_output_scale) + print('output features', self.out_features) + + def forward(self, image): # pylint: disable=arguments-differ + if isinstance(self.net, torch.nn.ModuleList): + if self.topology == 'linear': + intermediate = image + outputs = [] + for n in self.net: + intermediate = n(intermediate) + outputs.append(intermediate) + + return outputs + + if self.topology == 'fork': + intermediate = self.net[0](image) + return intermediate, self.net[1](intermediate), self.net[2](intermediate) + + return self.net(image) + + +class ResnetC4(BaseNetwork): + """Resnet capped after stage4. Default is a Resnet50. + + Spatial resolution of output is input resolution divided by 16. + Has an option to keep stage5. + """ + + def __init__(self, resnet, shortname=None, remove_pool0=True, + input_stride=2, pool0_stride=2, block5=False, + twostage=False, fork=False): + # print('===============') + # print(list(resnet.children())) + + if not block5: + # remove the linear, avgpool2d and stage5 + stump_modules = list(resnet.children())[:-3] + input_output_scale = 16 + out_features = 1024 + else: + # remove linear and avgpool2d + stump_modules = list(resnet.children())[:-2] + input_output_scale = 32 + out_features = 2048 + + if remove_pool0: + stump_modules.pop(3) + input_output_scale /= 2 + else: + if pool0_stride != 2: + stump_modules[3].stride = torch.nn.modules.utils._pair(pool0_stride) # pylint: disable=protected-access + input_output_scale *= pool0_stride / 2 + + if input_stride != 2: + stump_modules[0].stride = torch.nn.modules.utils._pair(input_stride) # pylint: disable=protected-access + input_output_scale *= input_stride / 2 + + if twostage: + stump = torch.nn.ModuleList([ + torch.nn.Sequential(*stump_modules[:-1]), + torch.nn.Sequential(*stump_modules[-1:]), + ]) + elif fork: + stump = torch.nn.ModuleList([ + torch.nn.Sequential(*stump_modules[:-1]), + torch.nn.Sequential(*stump_modules[-1:]), + copy.deepcopy(torch.nn.Sequential(*stump_modules[-1:])), + ]) + else: + stump = torch.nn.Sequential(*stump_modules) + + shortname = shortname or resnet.__class__.__name__ + super(ResnetC4, self).__init__(stump, shortname, input_output_scale, out_features) + if fork: + self.topology = 'fork' + + def atrous0(self, dilation): + convs = [m for m in self.net.modules() if isinstance(m, torch.nn.Conv2d)] + first_conv = convs[0] + + print('before atrous', list(self.net.children())) + print('model: stride = {}, dilation = {}, input_output = {}' + ''.format(first_conv.stride, first_conv.dilation, self.input_output_scale)) + + original_stride = first_conv.stride[0] + first_conv.stride = torch.nn.modules.utils._pair(original_stride // dilation) # pylint: disable=protected-access + first_conv.dilation = torch.nn.modules.utils._pair(dilation) # pylint: disable=protected-access + padding = (first_conv.kernel_size[0] - 1) // 2 * first_conv.dilation[0] + first_conv.padding = torch.nn.modules.utils._pair(padding) # pylint: disable=protected-access + + for conv in convs[1:]: + if conv.kernel_size[0] > 1: + conv.dilation = torch.nn.modules.utils._pair(dilation) # pylint: disable=protected-access + + padding = (conv.kernel_size[0] - 1) // 2 * conv.dilation[0] + conv.padding = torch.nn.modules.utils._pair(padding) # pylint: disable=protected-access + + self.input_output_scale /= dilation + print('after atrous', list(self.net.children())) + print('atrous modification: stride = {}, dilation = {}, input_output = {}' + ''.format(first_conv.stride, first_conv.dilation, self.input_output_scale)) + + def atrous(self, dilations): + """Apply atrous.""" + if isinstance(self.net, tuple): + children = list(self.net[0].children()) + list(self.net[1].children()) + else: + children = list(self.net.children()) + + layer3, layer4 = children[-2:] + print('before layer 3', layer3) + print('before layer 4', layer4) + + prev_dilations = [1] + list(dilations[:-1]) + for prev_dilation, dilation, layer in zip(prev_dilations, dilations, (layer3, layer4)): + if dilation == 1: + continue + + convs = [m for m in layer.modules() if isinstance(m, torch.nn.Conv2d)] + layer_stride = max(c.stride[0] for c in convs) + self.input_output_scale /= layer_stride + + for conv in convs: + if dilation != prev_dilation: + conv.stride = torch.nn.modules.utils._pair(1) # pylint: disable=protected-access + if conv.kernel_size[0] > 1: + conv.dilation = torch.nn.modules.utils._pair(dilation) # pylint: disable=protected-access + + padding = (conv.kernel_size[0] - 1) // 2 * dilation + conv.padding = torch.nn.modules.utils._pair(padding) # pylint: disable=protected-access + + print('after atrous layer 3', layer3) + print('after atrous layer 4', layer4) + + +class DownsampleCat(torch.nn.Module): + def __init__(self): + super(DownsampleCat, self).__init__() + self.pad = torch.nn.ConstantPad2d((0, 1, 0, 1), 0.0) + + def forward(self, x): # pylint: disable=arguments-differ + p = self.pad(x) + o = torch.cat((p[:, :, :-1:2, :-1:2], p[:, :, 1::2, 1::2]), dim=1) + return o + + +class ResnetBlocks(object): + def __init__(self, resnet): + self.modules = list(resnet.children()) + # print('===============') + # print(self.modules) + + def input_block(self, use_pool=False, conv_stride=2, pool_stride=2): + modules = self.modules[:4] + + if not use_pool: + modules.pop(3) + else: + if pool_stride != 2: + modules[3].stride = torch.nn.modules.utils._pair(pool_stride) # pylint: disable=protected-access + + if conv_stride != 2: + modules[0].stride = torch.nn.modules.utils._pair(conv_stride) # pylint: disable=protected-access + + return torch.nn.Sequential(*modules) + + @staticmethod + def dilation(block, dilation, stride=1): + convs = [m for m in block.modules() if isinstance(m, torch.nn.Conv2d)] + + for conv in convs: + if conv.kernel_size[0] == 1: + continue + + conv.dilation = torch.nn.modules.utils._pair(dilation) # pylint: disable=protected-access + + padding = (conv.kernel_size[0] - 1) // 2 * dilation + conv.padding = torch.nn.modules.utils._pair(padding) # pylint: disable=protected-access + + # TODO: check these are the right convolutions to adjust + for conv in convs[:2]: + conv.stride = torch.nn.modules.utils._pair(stride) # pylint: disable=protected-access + + return block + + @staticmethod + def stride(block): + """Compute the output stride of a block. + + Assume that convolutions are in serious with pools; only one + convolutions with non-unit stride. + """ + if isinstance(block, list): + stride = 1 + for b in block: + stride *= ResnetBlocks.stride(b) + return stride + + conv_stride = max(m.stride[0] + for m in block.modules() + if isinstance(m, torch.nn.Conv2d)) + + pool_stride = 1 + pools = [m for m in block.modules() if isinstance(m, torch.nn.MaxPool2d)] + if pools: + for p in pools: + pool_stride *= p.stride + + return conv_stride * pool_stride + + @staticmethod + def replace_downsample(block): + print('!!!!!!!!!!') + first_bottleneck = block[0] + print(first_bottleneck.downsample) + first_bottleneck.downsample = DownsampleCat() + print(first_bottleneck) + + @staticmethod + def out_channels(block): + """For blocks 2-5.""" + last_conv = list(block.modules())[-3] + return last_conv.out_channels + + def block2(self): + return self.modules[4] + + def block3(self): + return self.modules[5] + + def block4(self): + return self.modules[6] + + def block5(self): + return self.modules[7] + + +class DenseNet(BaseNetwork): + """DenseNet. Default is a densenet121. + + Spatial resolution of output is input resolution divided by 16. + """ + + def __init__(self, densenet, shortname=None, remove_pool0=True, adjust_input_stride=False): + # print('===============') + # print(list(densenet.children())) + input_output_scale = 32 + + # remove the last linear layer, and maxpool0 at the beginning + stump_modules = list(list(densenet.children())[0].children())[:-1] + if remove_pool0: + stump_modules.pop(3) + input_output_scale /= 2 + if adjust_input_stride: + stump_modules[0].stride = torch.nn.modules.utils._pair(1) # pylint: disable=protected-access + input_output_scale /= 2 + stump = torch.nn.Sequential(*stump_modules) + + shortname = shortname or densenet.__class__.__name__ + super(DenseNet, self).__init__(stump, shortname, input_output_scale) diff --git a/openpifpaf/network/heads.py b/openpifpaf/network/heads.py new file mode 100644 index 0000000..85b4e36 --- /dev/null +++ b/openpifpaf/network/heads.py @@ -0,0 +1,199 @@ +"""Head networks.""" + +from abc import ABCMeta, abstractstaticmethod +import logging +import re + +import torch + +LOG = logging.getLogger(__name__) + + +class Head(metaclass=ABCMeta): + @abstractstaticmethod + def match(head_name): # pylint: disable=unused-argument + return False + + @classmethod + def cli(cls, parser): + """Add decoder specific command line arguments to the parser.""" + + @classmethod + def apply_args(cls, args): + """Read command line arguments args to set class properties.""" + + +class CompositeField(Head, torch.nn.Module): + default_dropout_p = 0.0 + default_quad = 0 + default_kernel_size = 1 + default_padding = 0 + default_dilation = 1 + + def __init__(self, head_name, in_features, *, + n_fields=None, + n_confidences=1, n_vectors=None, n_scales=None, + kernel_size=None, padding=None, dilation=None): + super(CompositeField, self).__init__() + + n_fields = n_fields or self.determine_nfields(head_name) + n_vectors = n_vectors or self.determine_nvectors(head_name) + n_scales = n_scales or self.determine_nscales(head_name) + LOG.debug('%s loss: fields = %d, confidences = %d, vectors = %d, scales = %d', + head_name, n_fields, n_confidences, n_vectors, n_scales) + + #print("heads.py ", head_name, n_fields, n_vectors, n_scales) + #heads.py pif 17 1 1 + #heads.py paf 19 2 0 + #heads.py crm 2 0 0 + + if kernel_size is None: + kernel_size = {'wpaf': 3}.get(head_name, self.default_kernel_size) + if padding is None: + padding = {'wpaf': 5}.get(head_name, self.default_padding) + if dilation is None: + dilation = {'wpaf': 5}.get(head_name, self.default_dilation) + LOG.debug('%s loss: kernel = %d, padding = %d, dilation = %d', + head_name, kernel_size, padding, dilation) + + self.shortname = head_name + self.apply_class_sigmoid = True + self.dilation = dilation + + self.dropout = torch.nn.Dropout2d(p=self.default_dropout_p) + self._quad = self.default_quad + + # classification + out_features = n_fields * (4 ** self._quad) + self.class_convs = torch.nn.ModuleList([ + torch.nn.Conv2d(in_features, out_features, + kernel_size, padding=padding, dilation=dilation) + for _ in range(n_confidences) + ]) + + # regression + self.reg_convs = torch.nn.ModuleList([ + torch.nn.Conv2d(in_features, 2 * out_features, + kernel_size, padding=padding, dilation=dilation) + for _ in range(n_vectors) + ]) + self.reg_spreads = torch.nn.ModuleList([ + torch.nn.Conv2d(in_features, out_features, + kernel_size, padding=padding, dilation=dilation) + for _ in self.reg_convs + ]) + + # scale + self.scale_convs = torch.nn.ModuleList([ + torch.nn.Conv2d(in_features, out_features, + kernel_size, padding=padding, dilation=dilation) + for _ in range(n_scales) + ]) + + # dequad + self.dequad_op = torch.nn.PixelShuffle(2) + + @staticmethod + def determine_nfields(head_name): + m = re.match('p[ia]f([0-9]+)$', head_name) + if m is not None: + return int(m.group(1)) + + return { + 'paf': 19, + 'pafb': 19, + 'pafsb': 19, + 'pafs19': 19, + 'wpaf': 19, + 'crm': 2, + }.get(head_name, 17) + + @staticmethod + def determine_nvectors(head_name): + if 'pif' in head_name: + return 1 + if 'paf' in head_name: + return 2 + if 'crm' in head_name: + return 0 + return 0 + + @staticmethod + def determine_nscales(head_name): + if 'pif' in head_name: + return 1 + if 'paf' in head_name: + return 0 + if 'crm' in head_name: + return 0 + return 0 + + @staticmethod + def match(head_name): + return head_name in ( + 'pif', + 'paf', + 'pafs', + 'wpaf', + 'pafb', + 'pafs19', + 'pafsb', + 'crm', + ) or re.match('p[ia]f([0-9]+)$', head_name) is not None + + @classmethod + def cli(cls, parser): + group = parser.add_argument_group('head') + group.add_argument('--head-dropout', default=cls.default_dropout_p, type=float, + help='zeroing probability of feature in head input') + group.add_argument('--head-quad', default=cls.default_quad, type=int, + help='number of times to apply quad (subpixel conv) to heads') + group.add_argument('--head-kernel-size', default=cls.default_kernel_size, type=int) + group.add_argument('--head-padding', default=cls.default_padding, type=int) + group.add_argument('--head-dilation', default=cls.default_dilation, type=int) + + @classmethod + def apply_args(cls, args): + cls.default_dropout_p = args.head_dropout + cls.default_quad = args.head_quad + cls.default_kernel_size = args.head_kernel_size + cls.default_padding = args.head_padding + cls.default_dilation = args.head_dilation + + def forward(self, x): # pylint: disable=arguments-differ + x = self.dropout(x) + + # classification + classes_x = [class_conv(x) for class_conv in self.class_convs] + if self.apply_class_sigmoid or self.shortname=="crm": + classes_x = [torch.sigmoid(class_x) for class_x in classes_x] + + # regressions + regs_x = [reg_conv(x) * self.dilation for reg_conv in self.reg_convs] + regs_x_spread = [reg_spread(x) for reg_spread in self.reg_spreads] + regs_x_spread = [torch.nn.functional.leaky_relu(x + 3.0) - 3.0 for x in regs_x_spread] + + # scale + scales_x = [scale_conv(x) for scale_conv in self.scale_convs] + scales_x = [torch.nn.functional.relu(scale_x) for scale_x in scales_x] + + for _ in range(self._quad): + classes_x = [self.dequad_op(class_x)[:, :, :-1, :-1] + for class_x in classes_x] + regs_x = [self.dequad_op(reg_x)[:, :, :-1, :-1] + for reg_x in regs_x] + regs_x_spread = [self.dequad_op(reg_x_spread)[:, :, :-1, :-1] + for reg_x_spread in regs_x_spread] + scales_x = [self.dequad_op(scale_x)[:, :, :-1, :-1] + for scale_x in scales_x] + + regs_x = [ + reg_x.reshape(reg_x.shape[0], + reg_x.shape[1] // 2, + 2, + reg_x.shape[2], + reg_x.shape[3]) + for reg_x in regs_x + ] + + return classes_x + regs_x + regs_x_spread + scales_x diff --git a/openpifpaf/network/losses.py b/openpifpaf/network/losses.py new file mode 100644 index 0000000..b90dc6c --- /dev/null +++ b/openpifpaf/network/losses.py @@ -0,0 +1,361 @@ +"""Losses.""" + +from abc import ABCMeta, abstractstaticmethod +import logging +import re +import torch + +from ..data import (COCO_PERSON_SIGMAS, COCO_PERSON_SKELETON, KINEMATIC_TREE_SKELETON, + DENSER_COCO_PERSON_SKELETON) + +LOG = logging.getLogger(__name__) + + +class Loss(metaclass=ABCMeta): + @abstractstaticmethod + def match(head_name): # pylint: disable=unused-argument + return False + + @classmethod + def cli(cls, parser): + """Add decoder specific command line arguments to the parser.""" + + @classmethod + def apply_args(cls, args): + """Read command line arguments args to set class properties.""" + + +def laplace_loss(x1, x2, logb, t1, t2, weight=None): + """Loss based on Laplace Distribution. + + Loss for a single two-dimensional vector (x1, x2) with radial + spread b and true (t1, t2) vector. + """ + norm = torch.sqrt((x1 - t1)**2 + (x2 - t2)**2) + losses = 0.694 + logb + norm * torch.exp(-logb) + if weight is not None: + losses = losses * weight + return torch.sum(losses) + + +def l1_loss(x1, x2, _, t1, t2, weight=None): + """L1 loss. + + Loss for a single two-dimensional vector (x1, x2) + true (t1, t2) vector. + """ + losses = torch.sqrt((x1 - t1)**2 + (x2 - t2)**2) + if weight is not None: + losses = losses * weight + return torch.sum(losses) + + +class SmoothL1Loss(object): + def __init__(self, r_smooth, scale_required=True): + self.r_smooth = r_smooth + self.scale = None + self.scale_required = scale_required + + def __call__(self, x1, x2, _, t1, t2, weight=None): + """L1 loss. + + Loss for a single two-dimensional vector (x1, x2) + true (t1, t2) vector. + """ + if self.scale_required and self.scale is None: + raise Exception + if self.scale is None: + self.scale = 1.0 + + r = self.r_smooth * self.scale + d = torch.sqrt((x1 - t1)**2 + (x2 - t2)**2) + smooth_regime = d < r + + smooth_loss = 0.5 / r[smooth_regime] * d[smooth_regime] ** 2 + linear_loss = d[smooth_regime == 0] - (0.5 * r[smooth_regime == 0]) + losses = torch.cat((smooth_loss, linear_loss)) + + if weight is not None: + losses = losses * weight + + self.scale = None + return torch.sum(losses) + + +class MultiHeadLoss(torch.nn.Module): + def __init__(self, losses, lambdas): + super(MultiHeadLoss, self).__init__() + + self.losses = torch.nn.ModuleList(losses) + self.losses_pifpaf = self.losses[:2] + self.losses_crm = self.losses[2:] + self.lambdas = lambdas + + def forward(self, head_fields, head_targets, head): # pylint: disable=arguments-differ + + if head=="pifpaf": + + assert len(self.losses_pifpaf) == len(head_fields) + assert len(self.losses_pifpaf) <= len(head_targets) + flat_head_losses = [ll + for l, f, t in zip(self.losses_pifpaf, head_fields, head_targets) + for ll in l(f, t)] + assert len(self.lambdas) == len(flat_head_losses) + loss_values = [lam * l + for lam, l in zip(self.lambdas, flat_head_losses) + if l is not None] + total_loss = sum(loss_values) if loss_values else None + + if head=="crm": + + assert len(self.losses_crm) == len(head_fields) + assert len(self.losses_crm) <= len(head_targets) + flat_head_losses = [ll + for l, f, t in zip(self.losses_crm, head_fields, head_targets) + for ll in l(f, t)] + loss_values = [0.2 * l + for l in flat_head_losses + if l is not None] + total_loss = sum(loss_values) if loss_values else None + return total_loss, flat_head_losses + +class CompositeLoss(Loss, torch.nn.Module): + default_background_weight = 1.0 + default_multiplicity_correction = False + default_independence_scale = 3.0 + + def __init__(self, head_name, regression_loss, *, + n_vectors=None, n_scales=None, sigmas=None): + super(CompositeLoss, self).__init__() + + if n_vectors is None and 'pif' in head_name: + n_vectors = 1 + if n_vectors is None and 'paf' in head_name: + n_vectors = 2 + + if n_scales is None and 'pif' in head_name: + n_scales = 1 + if n_scales is None and 'paf' in head_name: + n_scales = 0 + + if sigmas is None and head_name == 'pif': + sigmas = [COCO_PERSON_SIGMAS] + if sigmas is None and 'pif' in head_name: + sigmas = [[1.0]] + if sigmas is None and head_name in ('paf', 'paf19', 'wpaf'): + sigmas = [ + [COCO_PERSON_SIGMAS[j1i - 1] for j1i, _ in COCO_PERSON_SKELETON], + [COCO_PERSON_SIGMAS[j2i - 1] for _, j2i in COCO_PERSON_SKELETON], + ] + if sigmas is None and head_name in ('paf16',): + sigmas = [ + [COCO_PERSON_SIGMAS[j1i - 1] for j1i, _ in KINEMATIC_TREE_SKELETON], + [COCO_PERSON_SIGMAS[j2i - 1] for _, j2i in KINEMATIC_TREE_SKELETON], + ] + if sigmas is None and head_name in ('paf44',): + sigmas = [ + [COCO_PERSON_SIGMAS[j1i - 1] for j1i, _ in DENSER_COCO_PERSON_SKELETON], + [COCO_PERSON_SIGMAS[j2i - 1] for _, j2i in DENSER_COCO_PERSON_SKELETON], + ] + + self.background_weight = self.default_background_weight + self.multiplicity_correction = self.default_multiplicity_correction + self.independence_scale = self.default_independence_scale + + self.n_vectors = n_vectors + self.n_scales = n_scales + if self.n_scales: + assert len(sigmas) == n_scales + + #print("losses.py ", head_name, sigmas) + + LOG.debug('%s: n_vectors = %d, n_scales = %d, len(sigmas) = %d', + head_name, n_vectors, n_scales, len(sigmas)) + + if sigmas is not None: + assert len(sigmas) == n_vectors + scales_to_kp = torch.tensor(sigmas) + scales_to_kp = torch.unsqueeze(scales_to_kp, 0) + scales_to_kp = torch.unsqueeze(scales_to_kp, -1) + scales_to_kp = torch.unsqueeze(scales_to_kp, -1) + self.register_buffer('scales_to_kp', scales_to_kp) + else: + self.scales_to_kp = None + + self.regression_loss = regression_loss or laplace_loss + + @staticmethod + def match(head_name): + return head_name in ( + 'pif', + 'paf', + 'pafs', + 'wpaf', + 'pafb', + 'pafs19', + 'pafsb', + ) or re.match('p[ia]f([0-9]+)$', head_name) is not None + + @classmethod + def cli(cls, parser): + # group = parser.add_argument_group('composite loss') + pass + + @classmethod + def apply_args(cls, args): + cls.default_background_weight = args.background_weight + cls.default_fixed_size = args.paf_fixed_size + cls.default_aspect_ratio = args.paf_aspect_ratio + + def forward(self, x, t): # pylint: disable=arguments-differ + + assert len(x) == 1 + 2 * self.n_vectors + self.n_scales + x_intensity = x[0] + x_regs = x[1:1 + self.n_vectors] + x_spreads = x[1 + self.n_vectors:1 + 2 * self.n_vectors] + x_scales = [] + if self.n_scales: + x_scales = x[1 + 2 * self.n_vectors:1 + 2 * self.n_vectors + self.n_scales] + + assert len(t) == 1 + self.n_vectors + 1 + target_intensity = t[0] + target_regs = t[1:1 + self.n_vectors] + target_scale = t[-1] + + bce_masks = torch.sum(target_intensity, dim=1, keepdim=True) > 0.5 + if torch.sum(bce_masks) < 1: + return None, None, None + + batch_size = x_intensity.shape[0] + + bce_target = torch.masked_select(target_intensity[:, :-1], bce_masks) + bce_weight = torch.ones_like(bce_target) + bce_weight[bce_target == 0] = self.background_weight + ce_loss = torch.nn.functional.binary_cross_entropy_with_logits( + torch.masked_select(x_intensity, bce_masks), + bce_target, + weight=bce_weight, + ) + + reg_losses = [None for _ in target_regs] + reg_masks = target_intensity[:, :-1] > 0.5 + if torch.sum(reg_masks) > 0: + weight = None + if self.multiplicity_correction: + assert len(target_regs) == 2 + lengths = torch.norm(target_regs[0] - target_regs[1], dim=2) + multiplicity = (lengths - 3.0) / self.independence_scale + multiplicity = torch.clamp(multiplicity, min=1.0) + multiplicity = torch.masked_select(multiplicity, reg_masks) + weight = 1.0 / multiplicity + + reg_losses = [] + for i, (x_reg, x_spread, target_reg) in enumerate(zip(x_regs, x_spreads, target_regs)): + if hasattr(self.regression_loss, 'scale'): + assert self.scales_to_kp is not None + self.regression_loss.scale = torch.masked_select( + torch.clamp(target_scale * self.scales_to_kp[i], 0.1, 1000.0), # pylint: disable=unsubscriptable-object + reg_masks, + ) + + reg_losses.append(self.regression_loss( + torch.masked_select(x_reg[:, :, 0], reg_masks), + torch.masked_select(x_reg[:, :, 1], reg_masks), + torch.masked_select(x_spread, reg_masks), + torch.masked_select(target_reg[:, :, 0], reg_masks), + torch.masked_select(target_reg[:, :, 1], reg_masks), + weight=weight, + ) / 1000.0 / batch_size) + + scale_losses = [] + if x_scales: + scale_losses = [ + torch.nn.functional.l1_loss( + torch.masked_select(x_scale, reg_masks), + torch.masked_select(target_scale * scale_to_kp, reg_masks), + reduction='sum', + ) / 1000.0 / batch_size + for x_scale, scale_to_kp in zip(x_scales, self.scales_to_kp) + ] + + return [ce_loss] + reg_losses + scale_losses + +class CrmLoss(Loss, torch.nn.Module): + + def __init__(self, head_name, regression_loss, *, + n_vectors=None, n_scales=None, sigmas=None): + super(CrmLoss, self).__init__() + + @staticmethod + def match(head_name): + return head_name in ( + 'crm', + ) + @classmethod + def cli(cls, parser): + return + @classmethod + def apply_args(cls, args): + return + + def forward(self, x, t): # pylint: disable=arguments-differ + + activity_loss = torch.nn.functional.mse_loss(x[0], t[0], reduction='sum') / x[0].shape[0] + return [activity_loss] + +def cli(parser): + group = parser.add_argument_group('losses') + group.add_argument('--r-smooth', type=float, default=0.0, + help='r_{smooth} for SmoothL1 regressions') + group.add_argument('--regression-loss', default='laplace', + choices=['smoothl1', 'smootherl1', 'l1', 'laplace'], + help='type of regression loss') + group.add_argument('--background-weight', default=1.0, type=float, + help='BCE weight of background') + group.add_argument('--paf-multiplicity-correction', + default=False, action='store_true', + help='use multiplicity correction for PAF loss') + group.add_argument('--paf-independence-scale', default=3.0, type=float, + help='linear length scale of independence for PAF regression') + + +def factory_from_args(args): + for loss in Loss.__subclasses__(): + loss.apply_args(args) + + return factory( + args.headnets, + args.lambdas, + args.regression_loss, + args.r_smooth, + ).to(device=args.device) + + +def factory(head_names, lambdas, reg_loss_name=None, r_smooth=None): + head_names = [h for h in head_names if h not in ('skeleton',)] + + if reg_loss_name == 'smoothl1': + reg_loss = SmoothL1Loss(r_smooth) + elif reg_loss_name == 'l1': + reg_loss = l1_loss + elif reg_loss_name == 'laplace': + reg_loss = laplace_loss + elif reg_loss_name is None: + reg_loss = laplace_loss + else: + raise Exception('unknown regression loss type {}'.format(reg_loss_name)) + + losses = [factory_loss(head_name, reg_loss) for head_name in head_names] + return MultiHeadLoss(losses, lambdas) + + +def factory_loss(head_name, reg_loss): + for loss in Loss.__subclasses__(): + logging.debug('checking whether loss %s matches %s', + loss.__name__, head_name) + if not loss.match(head_name): + continue + logging.info('selected loss %s for %s', loss.__name__, head_name) + return loss(head_name, reg_loss) + + raise Exception('unknown headname {} for loss'.format(head_name)) diff --git a/openpifpaf/network/nets.py b/openpifpaf/network/nets.py new file mode 100644 index 0000000..0e52d3e --- /dev/null +++ b/openpifpaf/network/nets.py @@ -0,0 +1,335 @@ +import logging +import itertools +import torch +import torchvision +import torchvision.models as models + +from . import basenetworks, heads + +# generate hash values with: shasum -a 256 filename.pkl + +# DEFAULT_MODEL = ('https://documents.epfl.ch/users/k/kr/kreiss/www/' +# 'resnet101block5-pif-paf-edge401-190313-100107-81e34321.pkl') +# DEFAULT_MODEL = ('https://documents.epfl.ch/users/k/kr/kreiss/www/' +# 'resnet50block5-pif-paf-edge401-190315-214317-8c9fbafe.pkl') + +RESNET50_MODEL = ('https://storage.googleapis.com/openpifpaf-pretrained/v0.5.0/' + 'resnet50block5-pif-paf-edge401-190424-122009-f26a1f53.pkl') +RESNET101_MODEL = ('https://storage.googleapis.com/openpifpaf-pretrained/v0.5.0/' + 'resnet101block5-pif-paf-edge401-190412-151013-513a2d2d.pkl') +RESNET152_MODEL = ('https://storage.googleapis.com/openpifpaf-pretrained/v0.5.0/' + 'resnet152block5-pif-paf-edge401-190412-121848-8d771fcc.pkl') + +LOG = logging.getLogger(__name__) + + +class Shell(torch.nn.Module): + def __init__(self, base_net, head_nets): + super(Shell, self).__init__() + + self.base_net = base_net + self.head_nets = torch.nn.ModuleList(head_nets) + + def io_scales(self): + return [self.base_net.input_output_scale // (2 ** getattr(h, '_quad', 0)) + for h in self.head_nets] + + def forward(self, x, head): # pylint: disable=arguments-differ + x = self.base_net(x) + + if head=="pifpaf": + return [hn(x) for hn in self.head_nets[:2]] + if head=="crm": + return [hn(x) for hn in self.head_nets[2:]] + +class Shell2Stage(torch.nn.Module): + def __init__(self, base_net, head_nets1, head_nets2): + super(Shell2Stage, self).__init__() + + self.base_net = base_net + self.head_nets1 = torch.nn.ModuleList(head_nets1) + self.head_nets2 = torch.nn.ModuleList(head_nets2) + + @property + def head_nets(self): + return list(self.head_nets1) + list(self.head_nets2) + + def io_scales(self): + return ( + [self.base_net.input_output_scale[0] for _ in self.head_nets1] + + [self.base_net.input_output_scale[1] for _ in self.head_nets2] + ) + + def forward(self, x): # pylint: disable=arguments-differ + x1, x2 = self.base_net(x) + h1 = [hn(x1) for hn in self.head_nets1] + h2 = [hn(x2) for hn in self.head_nets2] + return [h for hs in (h1, h2) for h in hs] + + +class ShellFork(torch.nn.Module): + def __init__(self, base_net, head_nets1, head_nets2, head_nets3): + super(ShellFork, self).__init__() + + self.base_net = base_net + self.head_nets1 = torch.nn.ModuleList(head_nets1) + self.head_nets2 = torch.nn.ModuleList(head_nets2) + self.head_nets3 = torch.nn.ModuleList(head_nets3) + + @property + def head_nets(self): + return list(self.head_nets1) + list(self.head_nets2) + list(self.head_nets3) + + def io_scales(self): + return ( + [self.base_net.input_output_scale[0] for _ in self.head_nets1] + + [self.base_net.input_output_scale[1] for _ in self.head_nets2] + + [self.base_net.input_output_scale[2] for _ in self.head_nets3] + ) + + def forward(self, x): # pylint: disable=arguments-differ + x1, x2, x3 = self.base_net(x) + h1 = [hn(x1) for hn in self.head_nets1] + h2 = [hn(x2) for hn in self.head_nets2] + h3 = [hn(x3) for hn in self.head_nets3] + return [h for hs in (h1, h2, h3) for h in hs] + + +def factory_from_args(args): + for head in heads.Head.__subclasses__(): + head.apply_args(args) + + return factory(checkpoint=args.checkpoint, + basenet=args.basenet, + headnets=args.headnets, + pretrained=args.pretrained) + + +# pylint: disable=too-many-branches +def model_migration(net_cpu): + for m in net_cpu.modules(): + if not isinstance(m, torch.nn.Conv2d): + continue + if not hasattr(m, 'padding_mode'): # introduced in PyTorch 1.1.0 + m.padding_mode = 'zeros' + + for head in net_cpu.head_nets: + head.shortname = head.shortname.replace('PartsIntensityFields', 'pif') + head.shortname = head.shortname.replace('PartsAssociationFields', 'paf') + if not hasattr(head, 'dropout') or head.dropout is None: + head.dropout = torch.nn.Dropout2d(p=0.0) + if not hasattr(head, '_quad'): + if hasattr(head, 'quad'): + head._quad = head.quad # pylint: disable=protected-access + else: + head._quad = 0 # pylint: disable=protected-access + if not hasattr(head, 'scale_conv'): + head.scale_conv = None + if not hasattr(head, 'reg1_spread'): + head.reg1_spread = None + if not hasattr(head, 'reg2_spread'): + head.reg2_spread = None + if head.shortname == 'pif17' and getattr(head, 'scale_conv') is not None: + head.shortname = 'pifs17' + if head._quad == 1 and not hasattr(head, 'dequad_op'): # pylint: disable=protected-access + head.dequad_op = torch.nn.PixelShuffle(2) + if not hasattr(head, 'class_convs') and hasattr(head, 'class_conv'): + head.class_convs = torch.nn.ModuleList([head.class_conv]) + + +# pylint: disable=too-many-branches +def factory(*, + checkpoint=None, + basenet=None, + headnets=('pif', 'paf'), + pretrained=True, + dilation=None, + dilation_end=None): + if not checkpoint and basenet: + print("Building from scratch") + net_cpu = factory_from_scratch(basenet, headnets, pretrained=pretrained) + epoch = 0 + else: + if not checkpoint: + checkpoint = torch.utils.model_zoo.load_url(RESNET50_MODEL) + elif checkpoint == 'resnet50': + checkpoint = torch.utils.model_zoo.load_url(RESNET50_MODEL) + elif checkpoint == 'resnet101': + checkpoint = torch.utils.model_zoo.load_url(RESNET101_MODEL) + elif checkpoint == 'resnet152': + checkpoint = torch.utils.model_zoo.load_url(RESNET152_MODEL) + elif checkpoint.startswith('http'): + checkpoint = torch.utils.model_zoo.load_url(checkpoint) + else: + print("Loading pretrained model") + checkpoint = torch.load(checkpoint) + net_cpu = checkpoint['model'] + epoch = checkpoint['epoch'] + + # initialize for eval + net_cpu.eval() + for head in net_cpu.head_nets: + head.apply_class_sigmoid = True + + # normalize for backwards compatibility + model_migration(net_cpu) + + if dilation is not None: + net_cpu.base_net.atrous0(dilation) + # for head in net_cpu.head_nets: + # head.dilation = dilation + if dilation_end is not None: + if dilation_end == 1: + net_cpu.base_net.atrous((1, 1)) + elif dilation_end == 2: + net_cpu.base_net.atrous((1, 2)) + elif dilation_end == 4: + net_cpu.base_net.atrous((2, 4)) + else: + raise Exception + # for head in net_cpu.head_nets: + # head.dilation = (dilation or 1.0) * dilation_end + + return net_cpu, epoch + + +# pylint: disable=too-many-branches +def factory_from_scratch(basename, headnames, *, pretrained=True): + if 'resnet50' in basename: + base_vision = torchvision.models.resnet50(pretrained) + elif 'resnet101' in basename: + base_vision = torchvision.models.resnet101(pretrained) + elif 'resnet152' in basename: + base_vision = torchvision.models.resnet152(pretrained) + elif 'resnet260' in basename: + assert pretrained is False + base_vision = torchvision.models.ResNet( + torchvision.models.resnet.Bottleneck, [3, 8, 72, 3]) + # elif basename == 'densenet121': + # basenet = basenetworks.DenseNet(torchvision.models.densenet121(pretrained), 'DenseNet121') + # else: + # raise Exception('basenet not supported') + else: + raise Exception('unknown base network in {}'.format(basename)) + + #if(0): + # print("Loading pifpaf pretrained basenet") + # pifpaf_pretrained_basenet = models.resnet50(pretrained=True) + # pifpaf_pretrained_basenet = torch.nn.Sequential(*(list(pifpaf_pretrained_basenet.children())[:-2])) + # pifpaf_pretrained_basenet.load_state_dict(torch.load("./outputs/pifpaf-resnet50-new.pt")) + # for pair in itertools.zip_longest(base_vision.state_dict().items(), pifpaf_pretrained_basenet.state_dict().items()): + # + # if(pair[1] is None): + # continue + # + # key_base_vision = pair[0][0] + # key_pifpaf_pretrained_basenet = pair[1][0] + # print(key_base_vision, key_pifpaf_pretrained_basenet) + # + # # print(key_pifpaf, key_resnet) + # # sanity check + # if(torch.all(torch.eq(base_vision.state_dict()[key_base_vision], pifpaf_pretrained_basenet.state_dict()[key_pifpaf_pretrained_basenet]))): + # print("check if value not supposed to be equal") + # + # #basenet.state_dict()[key_resnet] = net_cpu.state_dict()[key_pifpaf] + # base_vision.state_dict()[key_base_vision].copy_(pifpaf_pretrained_basenet.state_dict()[key_pifpaf_pretrained_basenet]) + # + # # should always be true + # if(not torch.all(torch.eq(base_vision.state_dict()[key_base_vision], pifpaf_pretrained_basenet.state_dict()[key_pifpaf_pretrained_basenet]))): + # print("should always be true!") + # print() + + resnet_factory = basenetworks.ResnetBlocks(base_vision) + + # input block + use_pool = 'pool0' in basename + conv_stride = 2 + if 'is4' in basename: + conv_stride = 4 + if 'is1' in basename: + conv_stride = 1 + pool_stride = 2 + if 'pool0s4' in basename: + pool_stride = 4 + + # all blocks + blocks = [ + resnet_factory.input_block(use_pool, conv_stride, pool_stride), + resnet_factory.block2(), + resnet_factory.block3(), + resnet_factory.block4(), + ] + if 'block5' in basename: + blocks.append(resnet_factory.block5()) + + # downsample + if 'concat' in basename: + for b in blocks[2:]: + resnet_factory.replace_downsample(b) + + def create_headnet(name, n_features): # pylint: disable=too-many-return-statements + + #print(name, n_features) + # pif 2048 + # paf 2048 + + for head in heads.Head.__subclasses__(): + logging.debug('checking whether head %s matches %s', + head.__name__, name) + if not head.match(name): + continue + logging.info('selected head %s for %s', head.__name__, name) + return head(name, n_features) + + raise Exception('unknown head to create an encoder: {}'.format(name)) + + if 'pifb' in headnames or 'pafb' in headnames: + basenet = basenetworks.BaseNetwork( + torch.nn.ModuleList([torch.nn.Sequential(*blocks[:-1]), blocks[-1]]), + basename, + [resnet_factory.stride(blocks[:-1]), resnet_factory.stride(blocks)], + [resnet_factory.out_channels(blocks[-2]), resnet_factory.out_channels(blocks[-1])], + ) + head1 = [create_headnet(h, basenet.out_features[0]) + for h in headnames if h.endswith('b')] + head2 = [create_headnet(h, basenet.out_features[1]) + for h in headnames if not h.endswith('b')] + return Shell2Stage(basenet, head1, head2) + + if 'ppif' in headnames: + # TODO + head2 = [create_headnet(h, basenet.out_features[1]) + for h in headnames if h == 'ppif'] + head3 = [create_headnet(h, basenet.out_features[2]) + for h in headnames if h != 'ppif'] + return ShellFork(basenet, [], head2, head3) + + basenet = basenetworks.BaseNetwork( + torch.nn.Sequential(*blocks), + basename, + resnet_factory.stride(blocks), + resnet_factory.out_channels(blocks[-1]), + ) + headnets = [create_headnet(h, basenet.out_features) for h in headnames if h != 'skeleton'] + return Shell(basenet, headnets) + + +def cli(parser): + group = parser.add_argument_group('network configuration') + group.add_argument('--checkpoint', default=None, + help=('Load a model from a checkpoint. ' + 'Use "resnet50", "resnet101" ' + 'or "resnet152" for pretrained OpenPifPaf models.')) + group.add_argument('--dilation', default=None, type=int, + help='apply atrous') + group.add_argument('--dilation-end', default=None, type=int, + help='apply atrous') + group.add_argument('--basenet', default=None, + help='base network, e.g. resnet50block5') + group.add_argument('--headnets', default=['pif', 'paf'], nargs='+', + help='head networks') + group.add_argument('--no-pretrain', dest='pretrained', default=True, action='store_false', + help='create model without ImageNet pretraining') + + for head in heads.Head.__subclasses__(): + head.cli(parser) diff --git a/openpifpaf/network/trainer.py b/openpifpaf/network/trainer.py new file mode 100644 index 0000000..ceb6b2a --- /dev/null +++ b/openpifpaf/network/trainer.py @@ -0,0 +1,341 @@ +"""Train a pifpaf net.""" + +import copy +import hashlib +import logging +import shutil +import time +import torch +import itertools +torch.set_printoptions(precision=10) + +class Trainer(object): + def __init__(self, model, loss, optimizer, out, *, + lr_scheduler=None, + log_interval=1, + device=None, + fix_batch_norm=False, + stride_apply=1, + ema_decay=None, + encoder_visualizer=None, + train_profile=None, + model_meta_data=None): + self.log = logging.getLogger(self.__class__.__name__) + + self.model = model + self.loss = loss + self.optimizer = optimizer + self.out = out + self.lr_scheduler = lr_scheduler + + self.log_interval = log_interval + self.device = device + self.fix_batch_norm = fix_batch_norm + self.stride_apply = stride_apply + + self.ema_decay = ema_decay + self.ema = None + self.ema_restore_params = None + + self.encoder_visualizer = encoder_visualizer + self.model_meta_data = model_meta_data + + if train_profile: + # monkey patch to profile self.train_batch() + self.train_batch_without_profile = self.train_batch + def train_batch_with_profile(*args, **kwargs): + with torch.autograd.profiler.profile() as prof: + result = self.train_batch_without_profile(*args, **kwargs) + print(prof.key_averages()) + print(prof.total_average()) + prof.export_chrome_trace(train_profile) + return result + self.train_batch = train_batch_with_profile + + def lr(self): + for param_group in self.optimizer.param_groups: + return param_group['lr'] + + def step_ema(self): + if self.ema is None: + return + + for p, ema_p in zip(self.model.parameters(), self.ema): + ema_p.mul_(1.0 - self.ema_decay).add_(self.ema_decay, p.data) + + def apply_ema(self): + if self.ema is None: + return + + self.log.info('applying ema') + self.ema_restore_params = copy.deepcopy( + [p.data for p in self.model.parameters()]) + for p, ema_p in zip(self.model.parameters(), self.ema): + p.data.copy_(ema_p) + + def ema_restore(self): + if self.ema_restore_params is None: + return + + self.log.info('restoring params from before ema') + for p, ema_p in zip(self.model.parameters(), self.ema_restore_params): + p.data.copy_(ema_p) + self.ema_restore_params = None + + def loop(self, train_scenes, val_scenes, epochs, start_epoch=0): + for _ in range(start_epoch): + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + for epoch in range(start_epoch, epochs): + self.train(train_scenes, epoch) + + self.write_model(epoch + 1, epoch == epochs - 1) + self.val(val_scenes, epoch + 1) + + def train_batch(self, data1, targets1, meta1, data2, targets2, meta2, apply_gradients=True): # pylint: disable=method-hidden + + if self.encoder_visualizer: + self.encoder_visualizer(data1, targets1, meta1) + + if self.device: + data1 = data1.to(self.device, non_blocking=True) + targets1 = [[t.to(self.device, non_blocking=True) for t in head] for head in targets1] + data2 = data2.to(self.device, non_blocking=True) + targets2 = [[t.to(self.device, non_blocking=True) for t in head] for head in targets2] + + if(1): + outputs1 = self.model(data1, head="pifpaf") + loss1, head_losses1 = self.loss(outputs1, targets1, head="pifpaf") + if loss1 is not None: + loss1.backward() + if apply_gradients: + self.optimizer.step() + self.optimizer.zero_grad() + self.step_ema() + + outputs2 = self.model(data2, head="crm") + loss2, head_losses2 = self.loss(outputs2, targets2, head="crm") + if loss2 is not None: + loss2.backward() + if apply_gradients: + self.optimizer.step() + self.optimizer.zero_grad() + self.step_ema() + + outputs3 = self.model(data1, head="pifpaf") + loss3, head_losses3 = self.loss(outputs3, targets1, head="pifpaf") + if loss3 is not None: + loss3.backward() + if apply_gradients: + self.optimizer.step() + self.optimizer.zero_grad() + self.step_ema() + + loss = loss3 + loss2 + head_losses = head_losses3 + head_losses2 + + if(0): + outputs1 = self.model(data1, head="pifpaf") + outputs2 = self.model(data2, head="crm") + loss1, head_losses1 = self.loss(outputs1, targets1, head="pifpaf") + loss2, head_losses2 = self.loss(outputs2, targets2, head="crm") + loss = loss1 + loss2 + head_losses = head_losses1 + head_losses2 + if loss is not None: + loss.backward() + if apply_gradients: + self.optimizer.step() + self.optimizer.zero_grad() + self.step_ema() + + return ( + float(loss.item()) if loss is not None else None, + [float(l.item()) if l is not None else None + for l in head_losses], + ) + + def val_batch(self, data, targets, head): + if self.device: + data = data.to(self.device, non_blocking=True) + targets = [[t.to(self.device, non_blocking=True) for t in head] for head in targets] + + with torch.no_grad(): + outputs = self.model(data, head=head) + loss, head_losses = self.loss(outputs, targets, head=head) + #outputs1 = self.model(data1, head="pifpaf") + #loss1, head_losses1 = self.loss(outputs1, targets1, head="pifpaf") + #outputs2 = self.model(data2, head="crm") + #loss2, head_losses2 = self.loss(outputs2, targets2, head="crm") + #loss = loss1 + loss2 + #head_losses = head_losses1 + head_losses2 + + return ( + float(loss.item()) if loss is not None else None, + [float(l.item()) if l is not None else None + for l in head_losses], + ) + + def train(self, scenes, epoch): + + #print(len(scenes[0])) + #sys.exit(0) + + start_time = time.time() + self.model.train() + if self.fix_batch_norm: + for m in self.model.modules(): + if isinstance(m, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d)): + # print('fixing parameters for {}. Min var = {}'.format( + # m, torch.min(m.running_var))) + m.eval() + # m.weight.requires_grad = False + # m.bias.requires_grad = False + + # avoid numerical instabilities + # (only seen sometimes when training with GPU) + # Variances in pretrained models can be as low as 1e-17. + # m.running_var.clamp_(min=1e-8) + m.eps = 1e-4 + self.ema_restore() + self.ema = None + + epoch_loss = 0.0 + head_epoch_losses = None + last_batch_end = time.time() + self.optimizer.zero_grad() + + #jaad has a larger number of iterations per epoch so we cycle over coco (scenes[0]) + #for batch_idx, ((data1, target1, meta1),(data2, target2, meta2)) in enumerate(zip(itertools.cycle(scenes[0]), scenes[1])): + for batch_idx, ((data1, target1, meta1),(data2, target2, meta2)) in enumerate(zip(scenes[0],scenes[1])): + + preprocess_time = time.time() - last_batch_end + + batch_start = time.time() + apply_gradients = batch_idx % self.stride_apply == 0 + loss, head_losses = self.train_batch(data1, target1, meta1, data2, target2, meta2, apply_gradients) + + # update epoch accumulates + if loss is not None: + epoch_loss += loss + if head_epoch_losses is None: + head_epoch_losses = [0.0 for _ in head_losses] + for i, head_loss in enumerate(head_losses): + if head_loss is None: + continue + head_epoch_losses[i] += head_loss + + batch_time = time.time() - batch_start + + # write training loss + if batch_idx % self.log_interval == 0: + self.log.info({ + 'type': 'train', + 'epoch': epoch, 'batch': batch_idx, 'n_batches': len(scenes[0]), + 'time': round(batch_time, 3), + 'data_time': round(preprocess_time, 3), + 'lr': self.lr(), + 'loss': round(loss, 3) if loss is not None else None, + 'head_losses': [round(l, 3) if l is not None else None + for l in head_losses], + }) + + # initialize ema + if self.ema is None and self.ema_decay: + self.ema = copy.deepcopy([p.data for p in self.model.parameters()]) + + last_batch_end = time.time() + + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + self.apply_ema() + self.log.info({ + 'type': 'train-epoch', + 'epoch': epoch + 1, + 'loss': round(epoch_loss / len(scenes[0]), 5), + 'head_losses': [round(l / len(scenes[0]), 5) for l in head_epoch_losses], + 'time': round(time.time() - start_time, 1), + }) + + def val(self, scenes, epoch): + + #print(len(scenes[0]), len(scenes[1])) + #sys.exit(0) + + start_time = time.time() + self.model.eval() + + epoch_loss1 = 0.0 + epoch_loss2 = 0.0 + head_epoch_losses_pifpaf = None + head_epoch_losses_crm = None + + # cannot cycle when validating + for data1, target1, _ in scenes[0]: + loss, head_losses = self.val_batch(data1, target1, head="pifpaf") + + # update epoch accumulates + if loss is not None: + epoch_loss1 += loss + if head_epoch_losses_pifpaf is None: + head_epoch_losses_pifpaf = [0.0 for _ in head_losses] + for i, head_loss in enumerate(head_losses): + if head_loss is None: + continue + head_epoch_losses_pifpaf[i] += head_loss + + # cannot cycle when validating + for data2, target2, _ in scenes[1]: + loss, head_losses = self.val_batch(data2, target2, head="crm") + + # update epoch accumulates + if loss is not None: + epoch_loss2 += loss + if head_epoch_losses_crm is None: + head_epoch_losses_crm = [0.0 for _ in head_losses] + for i, head_loss in enumerate(head_losses): + if head_loss is None: + continue + head_epoch_losses_crm[i] += head_loss + + eval_time = time.time() - start_time + + self.log.info({ + 'type': 'val-epoch', + 'epoch': epoch, + 'loss': round(epoch_loss1/len(scenes[0]) + epoch_loss2/len(scenes[1]), 5), + 'head_losses': [round(l / len(scenes[0]), 5) for l in head_epoch_losses_pifpaf] + [round(l / len(scenes[1]), 5) for l in head_epoch_losses_crm], + 'time': round(eval_time, 1), + }) + + def write_model(self, epoch, final=True): + self.model.cpu() + + if isinstance(self.model, torch.nn.DataParallel): + self.log.debug('Writing a dataparallel model.') + model = self.model.module + else: + self.log.debug('Writing a single-thread model.') + model = self.model + + filename = '{}.epoch{:03d}'.format(self.out, epoch) + self.log.debug('about to write model') + torch.save({ + 'model': model, + 'epoch': epoch, + 'meta': self.model_meta_data, + }, filename) + self.log.debug('model written') + + if final: + sha256_hash = hashlib.sha256() + with open(filename, 'rb') as f: + for byte_block in iter(lambda: f.read(8192), b''): + sha256_hash.update(byte_block) + file_hash = sha256_hash.hexdigest() + outname, _, outext = self.out.rpartition('.') + final_filename = '{}-{}.{}'.format(outname, file_hash[:8], outext) + shutil.copyfile(filename, final_filename) + + self.model.to(self.device)