diff --git a/.gitignore b/.gitignore index 0205d62..5aeb0ff 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ *.pyc .DS_Store +.idea/ diff --git a/README.md b/README.md old mode 100644 new mode 100755 index 3e1dba4..90891b8 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Progressive growing of GANs +# Face Aging with Progressive growing of GANs PyTorch implementation of [Progressive Growing of GANs for Improved Quality, Stability, and Variation](http://arxiv.org/abs/1710.10196). ## How to create CelebA-HQ dataset @@ -6,7 +6,10 @@ I borrowed `h5tool.py` from [official code](https://github.com/tkarras/progressi ``` python2 h5tool.py create_celeba_hq file_name_to_save /path/to/celeba_dataset/ /path/to/celeba_hq_deltas ``` - +This is what I used on my laptop +``` +python2 h5tool.py create_celeba_hq /Users/yuan/Downloads/CelebA-HQ /Users/yuan/Downloads/CelebA/Original\ CelebA/ /Users/yuan/Downloads/CelebA/CelebA-HQ-Deltas +``` I found that MD5 checking were always failed, so I just commented out the MD5 checking part([LN 568](https://github.com/github-pengge/PyTorch-progressive_growing_of_gans/blob/master/h5tool#L568) and [LN 589](https://github.com/github-pengge/PyTorch-progressive_growing_of_gans/blob/master/h5tool#L589)) With default setting, it took 1 day on my server. You can specific `num_threads` and `num_tasks` for accleration. @@ -16,7 +19,14 @@ You have to create CelebA-HQ dataset first, please follow the instructions above To obtain the similar results in `samples` directory, see `train_no_tanh.py` or `train.py` scipt for details(with default options). Both should work well. For example, you could run ``` -python train.py --gpu 0 --train_kimg 600 --transition_kimg 600 --lr 1e-3 --beta1 0 --beta2 0.99 --gan lsgan --first_resol 4 --target_resol 256 --no_tanh +conda create -n pytorch_p36 python=3.6 h5py matplotlib +source activate pytorch_p36 +conda install pytorch torchvision -c pytorch +conda install scipy +pip install tensorflow + +#0=first gpu, 1=2nd gpu ,2=3rd gpu etc... +python train.py --gpu 0,1,2 --train_kimg 600 --transition_kimg 600 --beta1 0 --beta2 0.99 --gan lsgan --first_resol 4 --target_resol 256 --no_tanh ``` `train_kimg`(`transition_kimg`) means after seeing `train_kimg * 1000`(`transition_kimg * 1000`) real images, switching to fade in(stabilize) phase. Currently only support LSGAN and GAN with `--no_noise` option, since WGAN-GP is unavailable, `--drift` option does not affect the result. `--no_tanh` means do not use `tanh` at generator's output layer. @@ -26,6 +36,11 @@ If you are Python 2 user, You'd better add this to the top of `train.py` since I from __future__ import print_function ``` + +Tensorboard +``` +tensorboard --logdir='./logs' +``` ## Update history * **Update(20171213)**: Update `data.py`, now when fading in, real images are weighted combination of current resolution images and 0.5x resolution images. This weighting trick is similar to the one used in Generator's outputs or Discriminator's inputs. This helps stabilize when fading in. @@ -73,6 +88,7 @@ from __future__ import print_function * **Update(20171111)**: It's still under implementation. I did not care design the structure, and now I had to reimplement(phase='fade in' is hard to implement under current structure). I also fixed some bugs, since reimplementation is needed, I do not plan to pull requests at this moment. -# Official implementation -Official implementation using lasagne can ben found at [tkarras/progressive_growing_of_gans](https://github.com/tkarras/progressive_growing_of_gans). +# Reference implementation +* https://github.com/github-pengge/PyTorch-progressive_growing_of_gans + diff --git a/began.py b/began.py old mode 100644 new mode 100755 diff --git a/debug.py b/debug.py old mode 100644 new mode 100755 index fb3fea7..1d60115 --- a/debug.py +++ b/debug.py @@ -3,8 +3,8 @@ import sys sys.path.append('./models') sys.path.append('./utils') -from model import * -from data import CelebA +from models.model import * +from utils.data import CelebA G = Generator(num_channels=3, resolution=1024, fmap_max=512, fmap_base=8192, latent_size=512) diff --git a/h5tool.py b/h5tool.py old mode 100644 new mode 100755 diff --git a/models/base_model.py b/models/base_model.py old mode 100644 new mode 100755 index ea37eab..521df90 --- a/models/base_model.py +++ b/models/base_model.py @@ -1,459 +1,481 @@ -# -*- coding: utf-8 -*- -import torch -import torch.nn as nn -from torch.autograd import Variable -from torch.nn.parameter import Parameter -from torch.nn import functional as F -from torch.nn.init import kaiming_normal, calculate_gain -import numpy as np -import sys -if sys.version_info.major == 3: - from functools import reduce - -DEBUG = False - - -class PixelNormLayer(nn.Module): - def __init__(self, eps=1e-8): - super(PixelNormLayer, self).__init__() - self.eps = eps - - def forward(self, x): - return x / (torch.mean(x**2, dim=1, keepdim=True) + 1e-8) ** 0.5 - - def __repr__(self): - return self.__class__.__name__ + '(eps = %s)' % (self.eps) - - -class WScaleLayer(nn.Module): - def __init__(self, incoming): - super(WScaleLayer, self).__init__() - self.incoming = incoming - self.scale = (torch.mean(self.incoming.weight.data ** 2)) ** 0.5 - self.incoming.weight.data.copy_(self.incoming.weight.data / self.scale) - self.bias = None - if self.incoming.bias is not None: - self.bias = self.incoming.bias - self.incoming.bias = None - - def forward(self, x): - x = self.scale * x - if self.bias is not None: - x += self.bias.view(1, self.bias.size()[0], 1, 1) - return x - - def __repr__(self): - param_str = '(incoming = %s)' % (self.incoming.__class__.__name__) - return self.__class__.__name__ + param_str - - -def mean(tensor, axis, **kwargs): - if isinstance(axis, int): - axis = [axis] - for ax in axis: - tensor = torch.mean(tensor, axis=ax, **kwargs) - return tensor - - -class MinibatchStatConcatLayer(nn.Module): - def __init__(self, averaging='all'): - super(MinibatchStatConcatLayer, self).__init__() - self.averaging = averaging.lower() - if 'group' in self.averaging: - self.n = int(self.averaging[5:]) - else: - assert self.averaging in ['all', 'flat', 'spatial', 'none', 'gpool'], 'Invalid averaging mode'%self.averaging - self.adjusted_std = lambda x, **kwargs: torch.sqrt(torch.mean((x - torch.mean(x, **kwargs)) ** 2, **kwargs) + 1e-8) - - def forward(self, x): - shape = list(x.size()) - target_shape = shape.copy() - vals = self.adjusted_std(x, dim=0, keepdim=True) - if self.averaging == 'all': - target_shape[1] = 1 - vals = torch.mean(vals, keepdim=True) - elif self.averaging == 'spatial': - if len(shape) == 4: - vals = mean(vals, axis=[2,3], keepdim=True) # torch.mean(torch.mean(vals, 2, keepdim=True), 3, keepdim=True) - elif self.averaging == 'none': - target_shape = [target_shape[0]] + [s for s in target_shape[1:]] - elif self.averaging == 'gpool': - if len(shape) == 4: - vals = mean(x, [0,2,3], keepdim=True) # torch.mean(torch.mean(torch.mean(x, 2, keepdim=True), 3, keepdim=True), 0, keepdim=True) - elif self.averaging == 'flat': - target_shape[1] = 1 - vals = torch.FloatTensor([self.adjusted_std(x)]) - else: # self.averaging == 'group' - target_shape[1] = self.n - vals = vals.view(self.n, self.shape[1]/self.n, self.shape[2], self.shape[3]) - vals = mean(vals, axis=0, keepdim=True).view(1, self.n, 1, 1) - vals = vals.expand(*target_shape) - return torch.cat([x, vals], 1) - - def __repr__(self): - return self.__class__.__name__ + '(averaging = %s)' % (self.averaging) - - -class MinibatchDiscriminationLayer(nn.Module): - def __init__(self, num_kernels): - super(MinibatchDiscriminationLayer, self).__init__() - self.num_kernels = num_kernels - - def forward(self, x): - pass - - - -class GDropLayer(nn.Module): - def __init__(self, mode='mul', strength=0.2, axes=(0,1), normalize=False): - super(GDropLayer, self).__init__() - self.mode = mode.lower() - assert self.mode in ['mul', 'drop', 'prop'], 'Invalid GDropLayer mode'%mode - self.strength = strength - self.axes = [axes] if isinstance(axes, int) else list(axes) - self.normalize = normalize - self.gain = None - - def forward(self, x, deterministic=False): - if deterministic or not self.strength: - return x - - rnd_shape = [s if axis in self.axes else 1 for axis, s in enumerate(x.size())] # [x.size(axis) for axis in self.axes] - if self.mode == 'drop': - p = 1 - self.strength - rnd = np.random.binomial(1, p=p, size=rnd_shape) / p - elif self.mode == 'mul': - rnd = (1 + self.strength) ** np.random.normal(size=rnd_shape) - else: - coef = self.strength * x.size(1) ** 0.5 - rnd = np.random.normal(size=rnd_shape) * coef + 1 - - if self.normalize: - rnd = rnd / np.linalg.norm(rnd, keepdims=True) - rnd = Variable(torch.from_numpy(rnd).type(x.data.type())) - if x.is_cuda: - rnd = rnd.cuda() - return x * rnd - - def __repr__(self): - param_str = '(mode = %s, strength = %s, axes = %s, normalize = %s)' % (self.mode, self.strength, self.axes, self.normalize) - return self.__class__.__name__ + param_str - - -class LayerNormLayer(nn.Module): - def __init__(self, incoming, eps=1e-4): - super(LayerNormLayer, self).__init__() - self.incoming = incoming - self.eps = eps - self.gain = Parameter(torch.FloatTensor([1.0]), requires_grad=True) - self.bias = None - - if self.incoming.bias is not None: - self.bias = self.incoming.bias - self.incoming.bias = None - - def forward(self, x): - x = x - mean(x, axis=range(1, len(x.size()))) - x = x * 1.0/(torch.sqrt(mean(x**2, axis=range(1, len(x.size())), keepdim=True) + self.eps)) - x = x * self.gain - if self.bias is not None: - x += self.bias - return x - - def __repr__(self): - param_str = '(incoming = %s, eps = %s)' % (self.incoming.__class__.__name__, self.eps) - return self.__class__.__name__ + param_str - - -def resize_activations(v, so): - si = list(v.size()) - so = list(so) - assert len(si) == len(so) and si[0] == so[0] - - # Decrease feature maps. - if si[1] > so[1]: - v = v[:, :so[1]] - - # Shrink spatial axes. - if len(si) == 4 and (si[2] > so[2] or si[3] > so[3]): - assert si[2] % so[2] == 0 and si[3] % so[3] == 0 - ks = (si[2] // so[2], si[3] // so[3]) - v = F.avg_pool2d(v, kernel_size=ks, stride=ks, ceil_mode=False, padding=0, count_include_pad=False) - - # Extend spatial axes. Below is a wrong implementation - # shape = [1, 1] - # for i in range(2, len(si)): - # if si[i] < so[i]: - # assert so[i] % si[i] == 0 - # shape += [so[i] // si[i]] - # else: - # shape += [1] - # v = v.repeat(*shape) - if si[2] < so[2]: - assert so[2] % si[2] == 0 and so[2] / si[2] == so[3] / si[3] # currently only support this case - v = F.upsample(v, scale_factor=so[2]//si[2], mode='nearest') - - # Increase feature maps. - if si[1] < so[1]: - z = torch.zeros((v.shape[0], so[1] - si[1]) + so[2:]) - v = torch.cat([v, z], 1) - return v - - -class GSelectLayer(nn.Module): - def __init__(self, pre, chain, post): - super(GSelectLayer, self).__init__() - assert len(chain) == len(post) - self.pre = pre - self.chain = chain - self.post = post - self.N = len(self.chain) - - def forward(self, x, y=None, cur_level=None, insert_y_at=None): - if cur_level is None: - cur_level = self.N # cur_level: physical index - if y is not None: - assert insert_y_at is not None - - min_level, max_level = int(np.floor(cur_level-1)), int(np.ceil(cur_level-1)) - min_level_weight, max_level_weight = int(cur_level+1)-cur_level, cur_level-int(cur_level) - - _from, _to, _step = 0, max_level+1, 1 - - if self.pre is not None: - x = self.pre(x) - - out = {} - if DEBUG: - print('G: level=%s, size=%s' % ('in', x.size())) - for level in range(_from, _to, _step): - if level == insert_y_at: - x = self.chain[level](x, y) - else: - x = self.chain[level](x) - - if DEBUG: - print('G: level=%d, size=%s' % (level, x.size())) - - if level == min_level: - out['min_level'] = self.post[level](x) - if level == max_level: - out['max_level'] = self.post[level](x) - x = resize_activations(out['min_level'], out['max_level'].size()) * min_level_weight + \ - out['max_level'] * max_level_weight - if DEBUG: - print('G:', x.size()) - return x - - -class DSelectLayer(nn.Module): - def __init__(self, pre, chain, inputs): - super(DSelectLayer, self).__init__() - assert len(chain) == len(inputs) - self.pre = pre - self.chain = chain - self.inputs = inputs - self.N = len(self.chain) - - def forward(self, x, y=None, cur_level=None, insert_y_at=None): - if cur_level is None: - cur_level = self.N # cur_level: physical index - if y is not None: - assert insert_y_at is not None - - max_level, min_level = int(np.floor(self.N-cur_level)), int(np.ceil(self.N-cur_level)) - min_level_weight, max_level_weight = int(cur_level+1)-cur_level, cur_level-int(cur_level) - - _from, _to, _step = min_level+1, self.N, 1 - - if self.pre is not None: - x = self.pre(x) - - if DEBUG: - print('D: level=%s, size=%s, max_level=%s, min_level=%s' % ('in', x.size(), max_level, min_level)) - - if max_level == min_level: - x = self.inputs[max_level](x) - if max_level == insert_y_at: - x = self.chain[max_level](x, y) - else: - x = self.chain[max_level](x) - else: - out = {} - tmp = self.inputs[max_level](x) - if max_level == insert_y_at: - tmp = self.chain[max_level](tmp, y) - else: - tmp = self.chain[max_level](tmp) - out['max_level'] = tmp - out['min_level'] = self.inputs[min_level](x) - x = resize_activations(out['min_level'], out['max_level'].size()) * min_level_weight + \ - out['max_level'] * max_level_weight - if min_level == insert_y_at: - x = self.chain[min_level](x, y) - else: - x = self.chain[min_level](x) - - for level in range(_from, _to, _step): - if level == insert_y_at: - x = self.chain[level](x, y) - else: - x = self.chain[level](x) - - if DEBUG: - print('D: level=%d, size=%s' % (level, x.size())) - return x - - -class AEDSelectLayer(nn.Module): - def __init__(self, pre, chain, nins): - super(AEDSelectLayer, self).__init__() - assert len(chain) == len(nins) - self.pre = pre - self.chain = chain - self.nins = nins - self.N = len(self.chain) // 2 - - def forward(self, x, cur_level=None): - if cur_level is None: - cur_level = self.N # cur_level: physical index - - max_level, min_level = int(np.floor(self.N-cur_level)), int(np.ceil(self.N-cur_level)) - min_level_weight, max_level_weight = int(cur_level+1)-cur_level, cur_level-int(cur_level) - - _from, _to, _step = min_level, self.N, 1 - - if self.pre is not None: - x = self.pre(x) - - if DEBUG: - print('D: level=%s, size=%s, max_level=%s, min_level=%s' % ('in', x.size(), max_level, min_level)) - - # encoder - if max_level == min_level: - in_max_level = 0 - else: - in_max_level = self.chain[max_level](self.nins[max_level](x)) - if DEBUG: - print('D: level=%s(max_level), size=%s, encoder' % (max_level, in_max_level.size())) - - for level in range(_from, _to, _step): - if level == min_level: - in_min_level = self.nins[level](x) - target_shape = in_max_level.size() if max_level != min_level else in_min_level.size() - x = min_level_weight * resize_activations(in_min_level, target_shape) + max_level_weight * in_max_level - x = self.chain[level](x) - - if DEBUG: - print('D: level=%s, size=%s, encoder' % (level, x.size())) - - # decoder - from_, to_, step_ = self.N, 2*self.N-min_level, 1 - for level in range(from_, to_, step_): - x = self.chain[level](x) - if level == 2*self.N-min_level-1: # min output level - out_min_level = self.nins[level](x) - - if DEBUG: - print('D: level=%s, size=%s, decoder' % (level, x.size())) - - if max_level == min_level: - out_max_level = 0 - else: - out_max_level = self.nins[2*self.N-max_level-1](self.chain[2*self.N-max_level-1](x)) - - target_shape = out_max_level.size() if max_level != min_level else out_min_level.size() - x = min_level_weight * resize_activations(out_min_level, target_shape) + max_level_weight * out_max_level - - if DEBUG: - print('D: level=%s, size=%s' % ('out', x.size())) - return x - - # if max_level == min_level: - # x = self.nins[max_level](x) - # if not min_level+1 == self.N-1: - # x = self.chain[max_level](x) - # if DEBUG: - # print('D: level=%d, size=%s' % (min_level, x.size())) - # else: - # out = {} - # tmp = self.nins[max_level](x) - # tmp = self.chain[max_level](tmp) - # out['max_level'] = tmp - # if DEBUG: - # print('D: level=%d, size=%s' % (max_level, tmp.size())) - # out['min_level'] = self.nins[min_level](x) - # x = resize_activations(out['min_level'], out['max_level'].size()) * min_level_weight + \ - # out['max_level'] * max_level_weight - # if not min_level == self.N-1: - # x = self.chain[min_level](x) - # if DEBUG: - # print('D: level=%d, size=%s' % (min_level, x.size())) - - # for level in range(_from, _to, _step): - # x = self.chain[level](x) - - # if DEBUG: - # print('D: level=%d, size=%s, encoder' % (level, x.size())) - - # for level in range(_to, _to-_from+_to, _step): - # x = self.chain[level](x) - - # if DEBUG: - # print('D: level=%d, size=%s, decoder' % (level, x.size())) - - # if min_level == max_level: - # if not min_level+1 == self.N-1: - # x = self.chain[_to-_from+_to](x) - # x = self.nins[_to-_from+_to+1](x) - # else: - # out = {} - # if not min_level+1 == self.N-1: - # tmp = self.chain[_to-_from+_to-1](x) - # else: - # tmp = x - # out['min_level'] = self.nins[_to-_from+_to](tmp) - # if DEBUG: - # print('D: level=%d, size=%s, min_level' % (_to-_from+_to, out['min_level'].size())) - # x = self.chain[_to-_from+_to](x) - # out['max_level'] = self.nins[_to-_from+_to+1](x) - # x = resize_activations(out['min_level'], out['max_level'].size()) * min_level_weight + \ - # out['max_level'] * max_level_weight - - # if DEBUG: - # print('D: size=%s' % (x.size(),)) - # return x - - -class ConcatLayer(nn.Module): - def __init__(self): - super(ConcatLayer, self).__init__() - - def forward(self, x, y): - return torch.cat([x, y], 1) - - -class ReshapeLayer(nn.Module): - def __init__(self, new_shape): - super(ReshapeLayer, self).__init__() - self.new_shape = new_shape # not include minibatch dimension - - def forward(self, x): - assert reduce(lambda u,v: u*v, self.new_shape) == reduce(lambda u,v: u*v, x.size()[1:]) - return x.view(-1, *self.new_shape) - - -def he_init(layer, nonlinearity='conv2d', param=None): - nonlinearity = nonlinearity.lower() - if nonlinearity not in ['linear', 'conv1d', 'conv2d', 'conv3d', 'relu', 'leaky_relu', 'sigmoid', 'tanh']: - if not hasattr(layer, 'gain') or layer.gain is None: - gain = 0 # default - else: - gain = layer.gain - elif nonlinearity == 'leaky_relu': - assert param is not None, 'Negative_slope(param) should be given.' - gain = calculate_gain(nonlinearity, param) - else: - gain = calculate_gain(nonlinearity) - kaiming_normal(layer.weight, a=gain) - +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +from torch.autograd import Variable +from torch.nn.parameter import Parameter +from torch.nn import functional as F +from torch.nn.init import kaiming_normal, calculate_gain +import numpy as np +import sys +if sys.version_info.major == 3: + from functools import reduce + +DEBUG = False + + +class PixelNormLayer(nn.Module): + """ + Pixelwise feature vector normalization. + """ + def __init__(self, eps=1e-8): + super(PixelNormLayer, self).__init__() + self.eps = eps + + def forward(self, x): + return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8) + + def __repr__(self): + return self.__class__.__name__ + '(eps = %s)' % (self.eps) + + +class WScaleLayer(nn.Module): + """ + Applies equalized learning rate to the preceding layer. + """ + def __init__(self, incoming): + super(WScaleLayer, self).__init__() + self.incoming = incoming + self.scale = (torch.mean(self.incoming.weight.data ** 2)) ** 0.5 + self.incoming.weight.data.copy_(self.incoming.weight.data / self.scale) + self.bias = None + if self.incoming.bias is not None: + self.bias = self.incoming.bias + self.incoming.bias = None + + def forward(self, x): + x = self.scale * x + if self.bias is not None: + x += self.bias.view(1, self.bias.size()[0], 1, 1) + return x + + def __repr__(self): + param_str = '(incoming = %s)' % (self.incoming.__class__.__name__) + return self.__class__.__name__ + param_str + + +def mean(tensor, axis, **kwargs): + if isinstance(axis, int): + axis = [axis] + for ax in axis: + tensor = torch.mean(tensor, axis=ax, **kwargs) + return tensor + + +class MinibatchStatConcatLayer(nn.Module): + """Minibatch stat concatenation layer. + - averaging tells how much averaging to use ('all', 'spatial', 'none') + """ + def __init__(self, averaging='all'): + super(MinibatchStatConcatLayer, self).__init__() + self.averaging = averaging.lower() + if 'group' in self.averaging: + self.n = int(self.averaging[5:]) + else: + assert self.averaging in ['all', 'flat', 'spatial', 'none', 'gpool'], 'Invalid averaging mode'%self.averaging + self.adjusted_std = lambda x, **kwargs: torch.sqrt(torch.mean((x - torch.mean(x, **kwargs)) ** 2, **kwargs) + 1e-8) #Tstdeps in the original implementation + + def forward(self, x): + shape = list(x.size()) + target_shape = shape.copy() + vals = self.adjusted_std(x, dim=0, keepdim=True)# per activation, over minibatch dim + if self.averaging == 'all': # average everything --> 1 value per minibatch + target_shape[1] = 1 + vals = torch.mean(vals, dim=1, keepdim=True)#vals = torch.mean(vals, keepdim=True) + + elif self.averaging == 'spatial': # average spatial locations + if len(shape) == 4: + vals = mean(vals, axis=[2,3], keepdim=True) # torch.mean(torch.mean(vals, 2, keepdim=True), 3, keepdim=True) + elif self.averaging == 'none': # no averaging, pass on all information + target_shape = [target_shape[0]] + [s for s in target_shape[1:]] + elif self.averaging == 'gpool': # EXPERIMENTAL: compute variance (func) over minibatch AND spatial locations. + if len(shape) == 4: + vals = mean(x, [0,2,3], keepdim=True) # torch.mean(torch.mean(torch.mean(x, 2, keepdim=True), 3, keepdim=True), 0, keepdim=True) + elif self.averaging == 'flat': # variance of ALL activations --> 1 value per minibatch + target_shape[1] = 1 + vals = torch.FloatTensor([self.adjusted_std(x)]) + else: # self.averaging == 'group' # average everything over n groups of feature maps --> n values per minibatch + target_shape[1] = self.n + vals = vals.view(self.n, self.shape[1]/self.n, self.shape[2], self.shape[3]) + vals = mean(vals, axis=0, keepdim=True).view(1, self.n, 1, 1) + vals = vals.expand(*target_shape) + return torch.cat([x, vals], 1) # feature-map concatanation + + def __repr__(self): + return self.__class__.__name__ + '(averaging = %s)' % (self.averaging) + + +class MinibatchDiscriminationLayer(nn.Module): + def __init__(self, num_kernels): + super(MinibatchDiscriminationLayer, self).__init__() + self.num_kernels = num_kernels + + def forward(self, x): + pass + + +class GDropLayer(nn.Module): + """ + # Generalized dropout layer. Supports arbitrary subsets of axes and different + # modes. Mainly used to inject multiplicative Gaussian noise in the network. + """ + def __init__(self, mode='mul', strength=0.2, axes=(0,1), normalize=False): + super(GDropLayer, self).__init__() + self.mode = mode.lower() + assert self.mode in ['mul', 'drop', 'prop'], 'Invalid GDropLayer mode'%mode + self.strength = strength + self.axes = [axes] if isinstance(axes, int) else list(axes) + self.normalize = normalize + self.gain = None + + def forward(self, x, deterministic=False): + if deterministic or not self.strength: + return x + + rnd_shape = [s if axis in self.axes else 1 for axis, s in enumerate(x.size())] # [x.size(axis) for axis in self.axes] + if self.mode == 'drop': + p = 1 - self.strength + rnd = np.random.binomial(1, p=p, size=rnd_shape) / p + elif self.mode == 'mul': + rnd = (1 + self.strength) ** np.random.normal(size=rnd_shape) + else: + coef = self.strength * x.size(1) ** 0.5 + rnd = np.random.normal(size=rnd_shape) * coef + 1 + + if self.normalize: + rnd = rnd / np.linalg.norm(rnd, keepdims=True) + rnd = Variable(torch.from_numpy(rnd).type(x.data.type())) + if x.is_cuda: + rnd = rnd.cuda() + return x * rnd + + def __repr__(self): + param_str = '(mode = %s, strength = %s, axes = %s, normalize = %s)' % (self.mode, self.strength, self.axes, self.normalize) + return self.__class__.__name__ + param_str + + +class LayerNormLayer(nn.Module): + """ + Layer normalization. Custom reimplementation based on the paper: https://arxiv.org/abs/1607.06450 + """ + def __init__(self, incoming, eps=1e-4): + super(LayerNormLayer, self).__init__() + self.incoming = incoming + self.eps = eps + self.gain = Parameter(torch.FloatTensor([1.0]), requires_grad=True) + self.bias = None + + if self.incoming.bias is not None: + self.bias = self.incoming.bias + self.incoming.bias = None + + def forward(self, x): + x = x - mean(x, axis=range(1, len(x.size()))) + x = x * 1.0/(torch.sqrt(mean(x**2, axis=range(1, len(x.size())), keepdim=True) + self.eps)) + x = x * self.gain + if self.bias is not None: + x += self.bias + return x + + def __repr__(self): + param_str = '(incoming = %s, eps = %s)' % (self.incoming.__class__.__name__, self.eps) + return self.__class__.__name__ + param_str + + +def resize_activations(v, so): + """ + Resize activation tensor 'v' of shape 'si' to match shape 'so'. + :param v: + :param so: + :return: + """ + si = list(v.size()) + so = list(so) + assert len(si) == len(so) and si[0] == so[0] + + # Decrease feature maps. + if si[1] > so[1]: + v = v[:, :so[1]] + + # Shrink spatial axes. + if len(si) == 4 and (si[2] > so[2] or si[3] > so[3]): + assert si[2] % so[2] == 0 and si[3] % so[3] == 0 + ks = (si[2] // so[2], si[3] // so[3]) + v = F.avg_pool2d(v, kernel_size=ks, stride=ks, ceil_mode=False, padding=0, count_include_pad=False) + + # Extend spatial axes. Below is a wrong implementation + # shape = [1, 1] + # for i in range(2, len(si)): + # if si[i] < so[i]: + # assert so[i] % si[i] == 0 + # shape += [so[i] // si[i]] + # else: + # shape += [1] + # v = v.repeat(*shape) + if si[2] < so[2]: + assert so[2] % si[2] == 0 and so[2] / si[2] == so[3] / si[3] # currently only support this case + v = F.upsample(v, scale_factor=so[2]//si[2], mode='nearest') + + # Increase feature maps. + if si[1] < so[1]: + z = torch.zeros((v.shape[0], so[1] - si[1]) + so[2:]) + v = torch.cat([v, z], 1) + return v + + +class GSelectLayer(nn.Module): + def __init__(self, pre, chain, post): + super(GSelectLayer, self).__init__() + assert len(chain) == len(post) + self.pre = pre + self.chain = chain + self.post = post + self.N = len(self.chain) + + def forward(self, x, y=None, cur_level=None, insert_y_at=None): + if cur_level is None: + cur_level = self.N # cur_level: physical index + if y is not None: + assert insert_y_at is not None + + min_level, max_level = int(np.floor(cur_level-1)), int(np.ceil(cur_level-1)) + min_level_weight, max_level_weight = int(cur_level+1)-cur_level, cur_level-int(cur_level) + + _from, _to, _step = 0, max_level+1, 1 + + if self.pre is not None: + x = self.pre(x) + + out = {} + if DEBUG: + print('G: level=%s, size=%s' % ('in', x.size())) + for level in range(_from, _to, _step): + if level == insert_y_at: + x = self.chain[level](x, y) + else: + x = self.chain[level](x) + + if DEBUG: + print('G: level=%d, size=%s' % (level, x.size())) + + if level == min_level: + out['min_level'] = self.post[level](x) + if level == max_level: + out['max_level'] = self.post[level](x) + x = resize_activations(out['min_level'], out['max_level'].size()) * min_level_weight + \ + out['max_level'] * max_level_weight + if DEBUG: + print('G:', x.size()) + return x + + +class DSelectLayer(nn.Module): + def __init__(self, pre, chain, inputs): + super(DSelectLayer, self).__init__() + assert len(chain) == len(inputs) + self.pre = pre + self.chain = chain + self.inputs = inputs + self.N = len(self.chain) + + def forward(self, x, y=None, cur_level=None, insert_y_at=None): + if cur_level is None: + cur_level = self.N # cur_level: physical index + if y is not None: + assert insert_y_at is not None + + max_level, min_level = int(np.floor(self.N-cur_level)), int(np.ceil(self.N-cur_level)) + min_level_weight, max_level_weight = int(cur_level+1)-cur_level, cur_level-int(cur_level) + + _from, _to, _step = min_level+1, self.N, 1 + + if self.pre is not None: + x = self.pre(x) + + if DEBUG: + print('D: level=%s, size=%s, max_level=%s, min_level=%s' % ('in', x.size(), max_level, min_level)) + + if max_level == min_level: + x = self.inputs[max_level](x) + if max_level == insert_y_at: + x = self.chain[max_level](x, y) + else: + x = self.chain[max_level](x) + else: + out = {} + tmp = self.inputs[max_level](x) + if max_level == insert_y_at: + tmp = self.chain[max_level](tmp, y) + else: + tmp = self.chain[max_level](tmp) + out['max_level'] = tmp + out['min_level'] = self.inputs[min_level](x) + x = resize_activations(out['min_level'], out['max_level'].size()) * min_level_weight + \ + out['max_level'] * max_level_weight + if min_level == insert_y_at: + x = self.chain[min_level](x, y) + else: + x = self.chain[min_level](x) + + for level in range(_from, _to, _step): + if level == insert_y_at: + x = self.chain[level](x, y) + else: + x = self.chain[level](x) + + if DEBUG: + print('D: level=%d, size=%s' % (level, x.size())) + return x + + +class AEDSelectLayer(nn.Module): + def __init__(self, pre, chain, nins): + super(AEDSelectLayer, self).__init__() + assert len(chain) == len(nins) + self.pre = pre + self.chain = chain + self.nins = nins + self.N = len(self.chain) // 2 + + def forward(self, x, cur_level=None): + if cur_level is None: + cur_level = self.N # cur_level: physical index + + max_level, min_level = int(np.floor(self.N-cur_level)), int(np.ceil(self.N-cur_level)) + min_level_weight, max_level_weight = int(cur_level+1)-cur_level, cur_level-int(cur_level) + + _from, _to, _step = min_level, self.N, 1 + + if self.pre is not None: + x = self.pre(x) + + if DEBUG: + print('D: level=%s, size=%s, max_level=%s, min_level=%s' % ('in', x.size(), max_level, min_level)) + + # encoder + if max_level == min_level: + in_max_level = 0 + else: + in_max_level = self.chain[max_level](self.nins[max_level](x)) + if DEBUG: + print('D: level=%s(max_level), size=%s, encoder' % (max_level, in_max_level.size())) + + for level in range(_from, _to, _step): + if level == min_level: + in_min_level = self.nins[level](x) + target_shape = in_max_level.size() if max_level != min_level else in_min_level.size() + x = min_level_weight * resize_activations(in_min_level, target_shape) + max_level_weight * in_max_level + x = self.chain[level](x) + + if DEBUG: + print('D: level=%s, size=%s, encoder' % (level, x.size())) + + # decoder + from_, to_, step_ = self.N, 2*self.N-min_level, 1 + for level in range(from_, to_, step_): + x = self.chain[level](x) + if level == 2*self.N-min_level-1: # min output level + out_min_level = self.nins[level](x) + + if DEBUG: + print('D: level=%s, size=%s, decoder' % (level, x.size())) + + if max_level == min_level: + out_max_level = 0 + else: + out_max_level = self.nins[2*self.N-max_level-1](self.chain[2*self.N-max_level-1](x)) + + target_shape = out_max_level.size() if max_level != min_level else out_min_level.size() + x = min_level_weight * resize_activations(out_min_level, target_shape) + max_level_weight * out_max_level + + if DEBUG: + print('D: level=%s, size=%s' % ('out', x.size())) + return x + + # if max_level == min_level: + # x = self.nins[max_level](x) + # if not min_level+1 == self.N-1: + # x = self.chain[max_level](x) + # if DEBUG: + # print('D: level=%d, size=%s' % (min_level, x.size())) + # else: + # out = {} + # tmp = self.nins[max_level](x) + # tmp = self.chain[max_level](tmp) + # out['max_level'] = tmp + # if DEBUG: + # print('D: level=%d, size=%s' % (max_level, tmp.size())) + # out['min_level'] = self.nins[min_level](x) + # x = resize_activations(out['min_level'], out['max_level'].size()) * min_level_weight + \ + # out['max_level'] * max_level_weight + # if not min_level == self.N-1: + # x = self.chain[min_level](x) + # if DEBUG: + # print('D: level=%d, size=%s' % (min_level, x.size())) + + # for level in range(_from, _to, _step): + # x = self.chain[level](x) + + # if DEBUG: + # print('D: level=%d, size=%s, encoder' % (level, x.size())) + + # for level in range(_to, _to-_from+_to, _step): + # x = self.chain[level](x) + + # if DEBUG: + # print('D: level=%d, size=%s, decoder' % (level, x.size())) + + # if min_level == max_level: + # if not min_level+1 == self.N-1: + # x = self.chain[_to-_from+_to](x) + # x = self.nins[_to-_from+_to+1](x) + # else: + # out = {} + # if not min_level+1 == self.N-1: + # tmp = self.chain[_to-_from+_to-1](x) + # else: + # tmp = x + # out['min_level'] = self.nins[_to-_from+_to](tmp) + # if DEBUG: + # print('D: level=%d, size=%s, min_level' % (_to-_from+_to, out['min_level'].size())) + # x = self.chain[_to-_from+_to](x) + # out['max_level'] = self.nins[_to-_from+_to+1](x) + # x = resize_activations(out['min_level'], out['max_level'].size()) * min_level_weight + \ + # out['max_level'] * max_level_weight + + # if DEBUG: + # print('D: size=%s' % (x.size(),)) + # return x + + +class ConcatLayer(nn.Module): + def __init__(self): + super(ConcatLayer, self).__init__() + + def forward(self, x, y): + return torch.cat([x, y], 1) + + +class ReshapeLayer(nn.Module): + def __init__(self, new_shape): + super(ReshapeLayer, self).__init__() + self.new_shape = new_shape # not include minibatch dimension + + def forward(self, x): + assert reduce(lambda u,v: u*v, self.new_shape) == reduce(lambda u,v: u*v, x.size()[1:]) + return x.view(-1, *self.new_shape) + + +def he_init(layer, nonlinearity='conv2d', param=None): + nonlinearity = nonlinearity.lower() + if nonlinearity not in ['linear', 'conv1d', 'conv2d', 'conv3d', 'relu', 'leaky_relu', 'sigmoid', 'tanh']: + if not hasattr(layer, 'gain') or layer.gain is None: + gain = 0 # default + else: + gain = layer.gain + elif nonlinearity == 'leaky_relu': + assert param is not None, 'Negative_slope(param) should be given.' + gain = calculate_gain(nonlinearity, param) + else: + gain = calculate_gain(nonlinearity) + kaiming_normal(layer.weight, a=gain) + diff --git a/models/model.py b/models/model.py old mode 100644 new mode 100755 index 2c111ba..b8ef295 --- a/models/model.py +++ b/models/model.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -from base_model import * +from models.base_model import * def G_conv(incoming, in_channels, out_channels, kernel_size, padding, nonlinearity, init, param=None, @@ -165,7 +165,9 @@ def __init__(self, negative_slope = 0.2 act = nn.LeakyReLU(negative_slope=negative_slope) + # input activation iact = 'leaky_relu' + # output activation output_act = nn.Sigmoid() if self.sigmoid_at_end else 'linear' output_iact = 'sigmoid' if self.sigmoid_at_end else 'linear' gdrop_param = {'mode': 'prop', 'strength': gdrop_strength} @@ -199,6 +201,7 @@ def __init__(self, net = D_conv(net, oc, self.get_nf(0), 4, 0, act, iact, negative_slope, False, self.use_wscale, self.use_gdrop, self.use_layernorm, gdrop_param) + # Increasing Variation Using MINIBATCH Standard Deviation if self.mbdisc_kernels: net += [MinibatchDiscriminationLayer(num_kernels=self.mbdisc_kernels)] diff --git a/models/test.py b/models/test.py old mode 100644 new mode 100755 index c0007cc..b64d4b1 --- a/models/test.py +++ b/models/test.py @@ -1,6 +1,6 @@ import sys sys.path.append('models') -from model import AutoencodingDiscriminator +from models import AutoencodingDiscriminator import torch from torch.autograd import Variable D = AutoencodingDiscriminator(3, 32) diff --git a/train.py b/train.py old mode 100644 new mode 100755 index 46d7c45..84516c6 --- a/train.py +++ b/train.py @@ -6,12 +6,12 @@ import sys, os, time sys.path.append('utils') sys.path.append('models') -from data import CelebA, RandomNoiseGenerator -from model import Generator, Discriminator +from utils.data import CelebA, RandomNoiseGenerator +from models.model import Generator, Discriminator import argparse import numpy as np from scipy.misc import imsave - +from utils.logger import Logger class PGGAN(): def __init__(self, G, D, data, noise, opts): @@ -20,12 +20,13 @@ def __init__(self, G, D, data, noise, opts): self.data = data self.noise = noise self.opts = opts - + self.current_time = time.strftime('%Y-%m-%d %H%M%S') + self.logger = Logger('./logs/' + self.current_time + "/") gpu = self.opts['gpu'] self.use_cuda = len(gpu) > 0 os.environ['CUDA_VISIBLE_DEVICES'] = gpu - self.bs_map = {2**R: self.get_bs(2**R) for R in range(2, 11)} + self.bs_map = {2**R: self.get_bs(2**R) for R in range(2, 11)} # batch size map keyed by resolution_level self.rows_map = {32: 8, 16: 4, 8: 4, 4: 2, 2: 2} self.restore_model() @@ -79,7 +80,7 @@ def get_bs(self, resolution): bs = 8 / 2**(min(2, R-7)) return int(bs) - def registe_on_gpu(self): + def register_on_gpu(self): if self.use_cuda: self.G.cuda() self.D.cuda() @@ -119,10 +120,13 @@ def compute_G_loss(self): return g_adv_loss + g_add_loss def compute_D_loss(self): - d_adv_loss = self.compute_adv_loss(self.d_real, True, 0.5) + self.compute_adv_loss(self.d_fake, False, 0.5)*self.opts['fake_weight'] + self.d_adv_loss_real = self.compute_adv_loss(self.d_real, True, 0.5) + self.d_adv_loss_fake = self.compute_adv_loss(self.d_fake, False, 0.5) * self.opts['fake_weight'] + d_adv_loss = self.d_adv_loss_real + self.d_adv_loss_fake d_add_loss = self.compute_additional_d_loss() self.d_adv_loss = self._get_data(d_adv_loss) self.d_add_loss = self._get_data(d_add_loss) + return d_adv_loss + d_add_loss def _rampup(self, epoch, rampup_length): @@ -139,6 +143,8 @@ def _rampdown_linear(self, epoch, num_epochs, rampdown_length): else: return 1.0 + '''Update Learning rate + ''' def update_lr(self, cur_nimg): for param_group in self.optim_G.param_groups: lrate_coef = self._rampup(cur_nimg / 1000.0, self.opts['rampup_kimg']) @@ -154,11 +160,16 @@ def postprocess(self): pass def _numpy2var(self, x): - var = Variable(torch.from_numpy(x)) + var = Variable(torch.from_numpy(x)) if self.use_cuda: var = var.cuda() return var + def _var2numpy(self, var): + if self.use_cuda: + return var.cpu().data.numpy() + return var.data.numpy() + # def add_noise(self, x): # # TODO: support more method of adding noise. # if self.opts.get('no_noise', False): @@ -216,6 +227,40 @@ def report(self, it, num_it, phase, resol): values = (it, num_it, phase, resol, self.g_loss, self.d_loss, self.g_adv_loss, self.g_add_loss, self.d_adv_loss, self.d_add_loss) print(formation % values) + def tensorboard(self, it, num_it, phase, resol, samples): + # (1) Log the scalar values + prefix = str(resol)+'/'+phase+'/' + info = {prefix + 'G_loss': self.g_loss, + prefix + 'G_adv_loss': self.g_adv_loss, + prefix + 'G_add_loss': self.g_add_loss, + prefix + 'D_loss': self.d_loss, + prefix + 'D_adv_loss': self.d_adv_loss, + prefix + 'D_add_loss': self.d_add_loss, + prefix + 'D_adv_loss_fake': self._get_data(self.d_adv_loss_fake), + prefix + 'D_adv_loss_real': self._get_data(self.d_adv_loss_real)} + + for tag, value in info.items(): + self.logger.scalar_summary(tag, value, it) + + # (2) Log values and gradients of the parameters (histogram) + for tag, value in self.G.named_parameters(): + tag = tag.replace('.', '/') + self.logger.histo_summary('G/' + prefix +tag, self._var2numpy(value), it) + if value.grad is not None: + self.logger.histo_summary('G/' + prefix +tag + '/grad', self._var2numpy(value.grad), it) + + for tag, value in self.D.named_parameters(): + tag = tag.replace('.', '/') + self.logger.histo_summary('D/' + prefix + tag, self._var2numpy(value), it) + if value.grad is not None: + self.logger.histo_summary('D/' + prefix + tag + '/grad', + self._var2numpy(value.grad), it) + + # (3) Log the images + # info = {'images': samples[:10]} + # for tag, images in info.items(): + # logger.image_summary(tag, images, it) + def train_phase(self, R, phase, batch_size, cur_nimg, from_it, total_it): assert total_it >= from_it resol = 2 ** (R+1) @@ -231,30 +276,37 @@ def train_phase(self, R, phase, batch_size, cur_nimg, from_it, total_it): z = self.noise(batch_size) x = self.data(batch_size, cur_resol, cur_level) - # preprocess + # ===preprocess=== self.preprocess(z, x) self.update_lr(cur_nimg) - # update D + # ===update D=== self.optim_D.zero_grad() self.forward_D(cur_level, detach=True) self.backward_D() - # update G + # ===update G=== self.optim_G.zero_grad() self.forward_G(cur_level) self.backward_G() - # report + # ===report === self.report(it, total_it, phase, cur_resol) cur_nimg += batch_size - # sampling + # ===generate sample images=== + samples = [] if (it % self.opts['sample_freq'] == 0) or it == total_it-1: - self.sample(os.path.join(self.opts['sample_dir'], '%dx%d-%s-%s.png' % (cur_resol, cur_resol, phase, str(it).zfill(6)))) + samples = self.sample() + imsave(os.path.join(self.opts['sample_dir'], + '%dx%d-%s-%s.png' % (cur_resol, cur_resol, phase, str(it).zfill(6))), samples) + + # ===tensorboard visualization=== + if (it % self.opts['sample_freq'] == 0) or it == total_it - 1: + self.tensorboard(it, total_it, phase, cur_resol, samples) - # save model + # ===save model=== if (it % self.opts['save_freq'] == 0 and it > 0) or it == total_it-1: self.save(os.path.join(self.opts['ckpt_dir'], '%dx%d-%s-%s' % (cur_resol, cur_resol, phase, str(it).zfill(6)))) @@ -262,7 +314,7 @@ def train(self): # prepare self.create_optimizer() self.create_criterion() - self.registe_on_gpu() + self.register_on_gpu() to_level = int(np.log2(self.opts['target_resol'])) from_level = int(np.log2(self._from_resol)) @@ -270,7 +322,7 @@ def train(self): train_kimg = int(self.opts['train_kimg'] * 1000) transition_kimg = int(self.opts['transition_kimg'] * 1000) - + for R in range(from_level-1, to_level): batch_size = self.bs_map[2 ** (R+1)] @@ -285,7 +337,7 @@ def train(self): _range = phases[phase] self.train_phase(R, phase, batch_size, _range[0]*batch_size, _range[0], _range[1]) - def sample(self, file_name): + def sample(self): batch_size = self.z.size(0) n_row = self.rows_map[batch_size] n_col = int(np.ceil(batch_size / float(n_row))) @@ -305,12 +357,11 @@ def sample(self, file_name): samples = np.concatenate(samples, axis=1).transpose([1, 2, 0]) half = samples.shape[1] // 2 - samples[:,:half,:] = samples[:,:half,:] - np.min(samples[:,:half,:]) - samples[:,:half,:] = samples[:,:half,:] / np.max(samples[:,:half,:]) - samples[:,half:,:] = samples[:,half:,:] - np.min(samples[:,half:,:]) - samples[:,half:,:] = samples[:,half:,:] / np.max(samples[:,half:,:]) - - imsave(file_name, samples) + samples[:, :half, :] = samples[:, :half, :] - np.min(samples[:, :half, :]) + samples[:, :half, :] = samples[:, :half, :] / np.max(samples[:, :half, :]) + samples[:, half:, :] = samples[:, half:, :] - np.min(samples[:, half:, :]) + samples[:, half:, :] = samples[:, half:, :] / np.max(samples[:, half:, :]) + return samples def save(self, file_name): g_file = file_name + '-G.pth' @@ -326,8 +377,8 @@ def save(self, file_name): parser.add_argument('--total_kimg', default=10000, type=float, help='total_kimg: a param to compute lr.') parser.add_argument('--rampup_kimg', default=10000, type=float, help='rampup_kimg.') parser.add_argument('--rampdown_kimg', default=10000, type=float, help='rampdown_kimg.') - parser.add_argument('--g_lr_max', default=1e-3, type=float, help='learning rate') - parser.add_argument('--d_lr_max', default=1e-3, type=float, help='learning rate') + parser.add_argument('--g_lr_max', default=1e-3, type=float, help='Generator learning rate') + parser.add_argument('--d_lr_max', default=1e-3, type=float, help='Discriminator learning rate') parser.add_argument('--fake_weight', default=0.1, type=float, help="weight of fake images' loss of D") parser.add_argument('--beta1', default=0, type=float, help='beta1 for adam') parser.add_argument('--beta2', default=0.99, type=float, help='beta2 for adam') @@ -335,11 +386,12 @@ def save(self, file_name): parser.add_argument('--first_resol', default=4, type=int, help='first resolution') parser.add_argument('--target_resol', default=256, type=int, help='target resolution') parser.add_argument('--drift', default=1e-3, type=float, help='drift, only available for wgan_gp.') + parser.add_argument('--mbstat_avg', default='all', type=str, help='MinibatchStatConcatLayer averaging strategy (Which dimensions to average the statistic over?)') parser.add_argument('--sample_freq', default=500, type=int, help='sampling frequency.') parser.add_argument('--save_freq', default=5000, type=int, help='save model frequency.') parser.add_argument('--exp_dir', default='./exp', type=str, help='experiment dir.') parser.add_argument('--no_noise', action='store_true', help='do not add noise to real data.') - parser.add_argument('--no_tanh', action='store_true', help='do not add noise to real data.') + parser.add_argument('--no_tanh', action='store_true', help='do not use tanh in the last layer of the generator.') parser.add_argument('--restore_dir', default='', type=str, help='restore from which exp dir.') parser.add_argument('--which_file', default='', type=str, help='restore from which file, e.g. 128x128-fade_in-105000.') @@ -348,7 +400,9 @@ def save(self, file_name): args = parser.parse_args() opts = {k:v for k,v in args._get_kwargs()} + # Dimensionality of the latent vector. latent_size = 512 + # Use sigmoid activation for the last layer? sigmoid_at_end = args.gan in ['lsgan', 'gan'] if hasattr(args, 'no_tanh'): tanh_at_end = False @@ -356,7 +410,7 @@ def save(self, file_name): tanh_at_end = True G = Generator(num_channels=3, latent_size=latent_size, resolution=args.target_resol, fmap_max=latent_size, fmap_base=8192, tanh_at_end=tanh_at_end) - D = Discriminator(num_channels=3, resolution=args.target_resol, fmap_max=latent_size, fmap_base=8192, sigmoid_at_end=sigmoid_at_end) + D = Discriminator(num_channels=3, mbstat_avg=args.mbstat_avg, resolution=args.target_resol, fmap_max=latent_size, fmap_base=8192, sigmoid_at_end=sigmoid_at_end) print(G) print(D) data = CelebA() diff --git a/train_no_tanh.py b/train_no_tanh.py old mode 100644 new mode 100755 index 065b553..b915d1d --- a/train_no_tanh.py +++ b/train_no_tanh.py @@ -1,272 +1,322 @@ -# -*- coding: utf-8 -*- -import torch -import torch.optim as optim -from torch.autograd import Variable -import sys, os, time -sys.path.append('utils') -sys.path.append('models') -from data import CelebA, RandomNoiseGenerator -from model import Generator, Discriminator -import argparse -import numpy as np -from scipy.misc import imsave - - -class PGGAN(): - def __init__(self, G, D, data, noise, opts): - self.G = G - self.D = D - self.data = data - self.noise = noise - self.opts = opts - - gpu = self.opts['gpu'] - self.use_cuda = len(gpu) > 0 - os.environ['CUDA_VISIBLE_DEVICES'] = gpu - - current_time = time.strftime('%Y-%m-%d %H%M%S') - self.opts['sample_dir'] = os.path.join(os.path.join(self.opts['exp_dir'], current_time), 'samples') - self.opts['ckpt_dir'] = os.path.join(os.path.join(self.opts['exp_dir'], current_time), 'ckpts') - os.makedirs(self.opts['sample_dir']) - os.makedirs(self.opts['ckpt_dir']) - - self.bs_map = {2**R: self.get_bs(2**R) for R in range(2, 11)} - self.rows_map = {32: 8, 16: 4, 8: 4, 4: 2, 2: 2} - - # save opts - with open(os.path.join(os.path.join(self.opts['exp_dir'], current_time), 'options.txt'), 'w') as f: - for k, v in self.opts.items(): - print('%s: %s' % (k, v), file=f) - print('batch_size_map: %s' % self.bs_map, file=f) - - def get_bs(self, resolution): - R = int(np.log2(resolution)) - if R < 7: - bs = 32 / 2**(max(0, R-4)) - else: - bs = 8 / 2**(min(2, R-7)) - return int(bs) - - def registe_on_gpu(self): - if self.use_cuda: - self.G.cuda() - self.D.cuda() - - def create_optimizer(self): - self.optim_G = optim.Adam(self.G.parameters(), lr=self.opts['lr'], betas=(self.opts['beta1'], self.opts['beta2'])) - self.optim_D = optim.Adam(self.D.parameters(), lr=self.opts['lr'], betas=(self.opts['beta1'], self.opts['beta2'])) - - def create_criterion(self): - # w is for gan - if self.opts['gan'] == 'lsgan': - self.adv_criterion = lambda p,t,w: torch.mean((p-t)**2) # sigmoid is applied here - elif self.opts['gan'] == 'wgan_gp': - self.adv_criterion = lambda p,t,w: (-2*t+1) * torch.mean(p) - elif self.opts['gan'] == 'gan': - lambda p,t,w: -w*(torch.mean(t*torch.log(p+1e-8)) + torch.mean((1-t)*torch.log(1-p+1e-8))) - else: - raise ValueError('Invalid/Unsupported GAN: %s.' % self.opts['gan']) - - def compute_adv_loss(self, prediction, target, w): - return self.adv_criterion(prediction, target, w) - - def compute_additional_g_loss(self): - return 0.0 - - def compute_additional_d_loss(self): # drifting loss and gradient penalty, weighting inside this function - return 0.0 - - def _get_data(self, d): - return d.data[0] if isinstance(d, Variable) else d - - def compute_G_loss(self): - g_adv_loss = self.compute_adv_loss(self.d_fake, True, 1) - g_add_loss = self.compute_additional_g_loss() - self.g_adv_loss = self._get_data(g_adv_loss) - self.g_add_loss = self._get_data(g_add_loss) - return g_adv_loss + g_add_loss - - def compute_D_loss(self): - d_adv_loss = self.compute_adv_loss(self.d_real, True, 0.5) + self.compute_adv_loss(self.d_fake, False, 0.5) - d_add_loss = self.compute_additional_d_loss() - self.d_adv_loss = self._get_data(d_adv_loss) - self.d_add_loss = self._get_data(d_add_loss) - return d_adv_loss + d_add_loss - - def postprocess(self): - # TODO: weight cliping or others - pass - - def _numpy2var(self, x): - var = Variable(torch.from_numpy(x)) - if self.use_cuda: - var = var.cuda() - return var - - def add_noise(self, x): - # TODO: support more method of adding noise. - if self.opts.get('no_noise', False): - return x - - if hasattr(self, '_d_'): - self._d_ = self._d_ * 0.9 + torch.mean(self.d_real).data[0] * 0.1 - else: - self._d_ = 0.0 - strength = 0.2 * max(0, self._d_ - 0.5)**2 - noise = self._numpy2var(np.random.randn(*x.size()).astype(np.float32) * strength) - return x + noise - - def preprocess(self, z, real): - self.z = self._numpy2var(z) - self.real = self._numpy2var(real) - - def forward_G(self, cur_level): - self.d_fake = self.D(self.fake, cur_level=cur_level) - - def forward_D(self, cur_level, detach=True): - self.fake = self.G(self.z, cur_level=cur_level) - self.d_real = self.D(self.add_noise(self.real), cur_level=cur_level) - self.d_fake = self.D(self.fake.detach() if detach else self.fake, cur_level=cur_level) - # print('d_real', self.d_real.view(-1)) - # print('d_fake', self.d_fake.view(-1)) - # print(self.fake[0].view(-1)) - - def backward_G(self): - g_loss = self.compute_G_loss() - g_loss.backward() - self.optim_G.step() - self.g_loss = self._get_data(g_loss) - - def backward_D(self, retain_graph=False): - d_loss = self.compute_D_loss() - d_loss.backward(retain_graph=retain_graph) - self.optim_D.step() - self.d_loss = self._get_data(d_loss) - - def report(self, it, num_it, phase, resol): - formation = 'Iter[%d|%d], %s, %s, G: %.3f, D: %.3f, G_adv: %.3f, G_add: %.3f, D_adv: %.3f, D_add: %.3f' - values = (it, num_it, phase, resol, self.g_loss, self.d_loss, self.g_adv_loss, self.g_add_loss, self.d_adv_loss, self.d_add_loss) - print(formation % values) - - def train(self): - # prepare - self.create_optimizer() - self.create_criterion() - self.registe_on_gpu() - - to_level = int(np.log2(self.opts['target_resol'])) - from_level = int(np.log2(self.opts['first_resol'])) - assert 2**to_level == self.opts['target_resol'] and 2**from_level == self.opts['first_resol'] and to_level >= from_level >= 2 - cur_level = from_level - - for R in range(from_level-1, to_level-1): - batch_size = self.bs_map[2 ** (R+1)] - train_kimg = int(self.opts['train_kimg'] * 1000) - transition_kimg = int(self.opts['transition_kimg'] * 1000) - if R == to_level-1: - transition_kimg = 0 - cur_nimg = 0 - _len = len(str(train_kimg + transition_kimg)) - _num_it = (train_kimg + transition_kimg) // batch_size - for it in range(_num_it): - # determined current level: int for stabilizing and float for fading in - cur_level = R + float(max(cur_nimg-train_kimg, 0)) / transition_kimg - cur_resol = 2 ** int(np.ceil(cur_level+1)) - phase = 'stabilize' if int(cur_level) == cur_level else 'fade_in' - - # get a batch noise and real images - z = self.noise(batch_size) - x = self.data(batch_size, cur_resol, cur_level) - - # preprocess - self.preprocess(z, x) - - # update D - self.optim_D.zero_grad() - self.forward_D(cur_level, detach=True) # TODO: feed gdrop_strength - self.backward_D() - - # update G - self.optim_G.zero_grad() - self.forward_G(cur_level) - self.backward_G() - - # report - self.report(it, _num_it, phase, cur_resol) - - cur_nimg += batch_size - - # sampling - if (it % self.opts['sample_freq'] == 0) or it == _num_it-1: - self.sample(os.path.join(self.opts['sample_dir'], '%dx%d-%s-%s.png' % (cur_resol, cur_resol, phase, str(it).zfill(6)))) - - # save model - if (it % self.opts['save_freq'] == 0 and it > 0) or it == _num_it-1: - self.save(os.path.join(self.opts['ckpt_dir'], '%dx%d-%s-%s' % (cur_resol, cur_resol, phase, str(it).zfill(6)))) - - def sample(self, file_name): - batch_size = self.z.size(0) - n_row = self.rows_map[batch_size] - n_col = int(np.ceil(batch_size / float(n_row))) - samples = [] - i = j = 0 - for row in range(n_row): - one_row = [] - # fake - for col in range(n_col): - one_row.append(self.fake[i].cpu().data.numpy()) - i += 1 - # real - for col in range(n_col): - one_row.append(self.real[j].cpu().data.numpy()) - j += 1 - samples += [np.concatenate(one_row, axis=2)] - samples = np.concatenate(samples, axis=1).transpose([1, 2, 0]) - - half = samples.shape[1] // 2 - samples[:,:half,:] = samples[:,:half,:] - np.min(samples[:,:half,:]) - samples[:,:half,:] = samples[:,:half,:] / np.max(samples[:,:half,:]) - samples[:,half:,:] = samples[:,half:,:] - np.min(samples[:,half:,:]) - samples[:,half:,:] = samples[:,half:,:] / np.max(samples[:,half:,:]) - - imsave(file_name, samples) - - def save(self, file_name): - g_file = file_name + '-G.pth' - d_file = file_name + '-D.pth' - torch.save(self.G.state_dict(), g_file) - torch.save(self.D.state_dict(), d_file) - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--gpu', default='', type=str, help='gpu(s) to use.') - parser.add_argument('--train_kimg', default=600, type=float, help='# * 1000 real samples for each stabilizing training phase.') - parser.add_argument('--transition_kimg', default=600, type=float, help='# * 1000 real samples for each fading in phase.') - parser.add_argument('--lr', default=1e-3, type=float, help='learning rate') - parser.add_argument('--beta1', default=0, type=float, help='beta1 for adam') - parser.add_argument('--beta2', default=0.99, type=float, help='beta2 for adam') - parser.add_argument('--gan', default='lsgan', type=str, help='model: lsgan/wgan_gp/gan, currently only support lsgan or gan with no_noise option.') - parser.add_argument('--first_resol', default=4, type=int, help='first resolution') - parser.add_argument('--target_resol', default=256, type=int, help='target resolution') - parser.add_argument('--drift', default=1e-3, type=float, help='drift, only available for wgan_gp.') - parser.add_argument('--sample_freq', default=500, type=int, help='sampling frequency.') - parser.add_argument('--save_freq', default=5000, type=int, help='save model frequency.') - parser.add_argument('--exp_dir', default='./exp', type=str, help='experiment dir.') - parser.add_argument('--no_noise', action='store_true', help='do not add noise to real data.') - - # TODO: support conditional inputs - - args = parser.parse_args() - opts = {k:v for k,v in args._get_kwargs()} - - latent_size = 512 - sigmoid_at_end = args.gan in ['lsgan', 'gan'] - - G = Generator(num_channels=3, latent_size=latent_size, resolution=args.target_resol, fmap_max=latent_size, fmap_base=8192, tanh_at_end=False) - D = Discriminator(num_channels=3, resolution=args.target_resol, fmap_max=latent_size, fmap_base=8192, sigmoid_at_end=sigmoid_at_end) - print(G) - print(D) - data = CelebA() - noise = RandomNoiseGenerator(latent_size, 'gaussian') - pggan = PGGAN(G, D, data, noise, opts) - pggan.train() +# -*- coding: utf-8 -*- +import torch +import torch.optim as optim +from torch.autograd import Variable +import sys, os, time +sys.path.append('utils') +sys.path.append('models') +from utils.data import CelebA, RandomNoiseGenerator +from models.model import Generator, Discriminator +import argparse +import numpy as np +from scipy.misc import imsave +from utils.logger import Logger + +class PGGAN(): + def __init__(self, G, D, data, noise, opts): + self.G = G + self.D = D + self.data = data + self.noise = noise + self.opts = opts + self.current_time = time.strftime('%Y-%m-%d %H%M%S') + self.logger = Logger('./logs/' + self.current_time + "/") + gpu = self.opts['gpu'] + self.use_cuda = len(gpu) > 0 + os.environ['CUDA_VISIBLE_DEVICES'] = gpu + + current_time = time.strftime('%Y-%m-%d %H%M%S') + self.opts['sample_dir'] = os.path.join(os.path.join(self.opts['exp_dir'], current_time), 'samples') + self.opts['ckpt_dir'] = os.path.join(os.path.join(self.opts['exp_dir'], current_time), 'ckpts') + os.makedirs(self.opts['sample_dir']) + os.makedirs(self.opts['ckpt_dir']) + + self.bs_map = {2**R: self.get_bs(2**R) for R in range(2, 11)} + self.rows_map = {32: 8, 16: 4, 8: 4, 4: 2, 2: 2} + + # save opts + with open(os.path.join(os.path.join(self.opts['exp_dir'], current_time), 'options.txt'), 'w') as f: + for k, v in self.opts.items(): + print('%s: %s' % (k, v), file=f) + print('batch_size_map: %s' % self.bs_map, file=f) + + def get_bs(self, resolution): + R = int(np.log2(resolution)) + if R < 7: + bs = 32 / 2**(max(0, R-4)) + else: + bs = 8 / 2**(min(2, R-7)) + return int(bs) + + def register_on_gpu(self): + if self.use_cuda: + self.G.cuda() + self.D.cuda() + + def create_optimizer(self): + self.optim_G = optim.Adam(self.G.parameters(), lr=self.opts['g_lr_max'], betas=(self.opts['beta1'], self.opts['beta2'])) + self.optim_D = optim.Adam(self.D.parameters(), lr=self.opts['d_lr_max'], betas=(self.opts['beta1'], self.opts['beta2'])) + + def create_criterion(self): + # w is for gan + if self.opts['gan'] == 'lsgan': + self.adv_criterion = lambda p,t,w: torch.mean((p-t)**2) # sigmoid is applied here + elif self.opts['gan'] == 'wgan_gp': + self.adv_criterion = lambda p,t,w: (-2*t+1) * torch.mean(p) + elif self.opts['gan'] == 'gan': + lambda p,t,w: -w*(torch.mean(t*torch.log(p+1e-8)) + torch.mean((1-t)*torch.log(1-p+1e-8))) + else: + raise ValueError('Invalid/Unsupported GAN: %s.' % self.opts['gan']) + + def compute_adv_loss(self, prediction, target, w): + return self.adv_criterion(prediction, target, w) + + def compute_additional_g_loss(self): + return 0.0 + + def compute_additional_d_loss(self): # drifting loss and gradient penalty, weighting inside this function + return 0.0 + + def _get_data(self, d): + return d.data[0] if isinstance(d, Variable) else d + + def compute_G_loss(self): + g_adv_loss = self.compute_adv_loss(self.d_fake, True, 1) + g_add_loss = self.compute_additional_g_loss() + self.g_adv_loss = self._get_data(g_adv_loss) + self.g_add_loss = self._get_data(g_add_loss) + return g_adv_loss + g_add_loss + + def compute_D_loss(self): + self.d_adv_loss_real = self.compute_adv_loss(self.d_real, True, 0.5) + self.d_adv_loss_fake = self.compute_adv_loss(self.d_fake, False, 0.5) * self.opts['fake_weight'] + d_adv_loss = self.d_adv_loss_real + self.d_adv_loss_fake + d_add_loss = self.compute_additional_d_loss() + self.d_adv_loss = self._get_data(d_adv_loss) + self.d_add_loss = self._get_data(d_add_loss) + + return d_adv_loss + d_add_loss + + def postprocess(self): + # TODO: weight cliping or others + pass + + def _numpy2var(self, x): + var = Variable(torch.from_numpy(x)) + if self.use_cuda: + var = var.cuda() + return var + + def _var2numpy(self, var): + if self.use_cuda: + return var.cpu().data.numpy() + return var.data.numpy() + + def add_noise(self, x): + # TODO: support more method of adding noise. + if self.opts.get('no_noise', False): + return x + + if hasattr(self, '_d_'): + self._d_ = self._d_ * 0.9 + torch.mean(self.d_real).data[0] * 0.1 + else: + self._d_ = 0.0 + strength = 0.2 * max(0, self._d_ - 0.5)**2 + noise = self._numpy2var(np.random.randn(*x.size()).astype(np.float32) * strength) + return x + noise + + def preprocess(self, z, real): + self.z = self._numpy2var(z) + self.real = self._numpy2var(real) + + def forward_G(self, cur_level): + self.d_fake = self.D(self.fake, cur_level=cur_level) + + def forward_D(self, cur_level, detach=True): + self.fake = self.G(self.z, cur_level=cur_level) + self.d_real = self.D(self.add_noise(self.real), cur_level=cur_level) + self.d_fake = self.D(self.fake.detach() if detach else self.fake, cur_level=cur_level) + # print('d_real', self.d_real.view(-1)) + # print('d_fake', self.d_fake.view(-1)) + # print(self.fake[0].view(-1)) + + def backward_G(self): + g_loss = self.compute_G_loss() + g_loss.backward() + self.optim_G.step() + self.g_loss = self._get_data(g_loss) + + def backward_D(self, retain_graph=False): + d_loss = self.compute_D_loss() + d_loss.backward(retain_graph=retain_graph) + self.optim_D.step() + self.d_loss = self._get_data(d_loss) + + def report(self, it, num_it, phase, resol): + formation = 'Iter[%d|%d], %s, %s, G: %.3f, D: %.3f, G_adv: %.3f, G_add: %.3f, D_adv: %.3f, D_add: %.3f' + values = (it, num_it, phase, resol, self.g_loss, self.d_loss, self.g_adv_loss, self.g_add_loss, self.d_adv_loss, self.d_add_loss) + print(formation % values) + + def tensorboard(self, it, num_it, phase, resol, samples): + # (1) Log the scalar values + prefix = str(resol)+'/'+phase+'/' + info = {prefix + 'G_loss': self.g_loss, + prefix + 'G_adv_loss': self.g_adv_loss, + prefix + 'G_add_loss': self.g_add_loss, + prefix + 'D_loss': self.d_loss, + prefix + 'D_adv_loss': self.d_adv_loss, + prefix + 'D_add_loss': self.d_add_loss, + prefix + 'D_adv_loss_fake': self._get_data(self.d_adv_loss_fake), + prefix + 'D_adv_loss_real': self._get_data(self.d_adv_loss_real)} + + for tag, value in info.items(): + self.logger.scalar_summary(tag, value, it) + + # (2) Log values and gradients of the parameters (histogram) + for tag, value in self.G.named_parameters(): + tag = tag.replace('.', '/') + self.logger.histo_summary('G/' + prefix +tag, self._var2numpy(value), it) + if value.grad is not None: + self.logger.histo_summary('G/' + prefix +tag + '/grad', self._var2numpy(value.grad), it) + + for tag, value in self.D.named_parameters(): + tag = tag.replace('.', '/') + self.logger.histo_summary('D/' + prefix + tag, self._var2numpy(value), it) + if value.grad is not None: + self.logger.histo_summary('D/' + prefix + tag + '/grad', + self._var2numpy(value.grad), it) + + # (3) Log the images + # info = {'images': samples[:10]} + # for tag, images in info.items(): + # logger.image_summary(tag, images, it) + + def train(self): + # prepare + self.create_optimizer() + self.create_criterion() + self.registe_on_gpu() + + to_level = int(np.log2(self.opts['target_resol'])) + from_level = int(np.log2(self.opts['first_resol'])) + assert 2**to_level == self.opts['target_resol'] and 2**from_level == self.opts['first_resol'] and to_level >= from_level >= 2 + cur_level = from_level + + for R in range(from_level-1, to_level-1): + batch_size = self.bs_map[2 ** (R+1)] + train_kimg = int(self.opts['train_kimg'] * 1000) + transition_kimg = int(self.opts['transition_kimg'] * 1000) + if R == to_level-1: + transition_kimg = 0 + cur_nimg = 0 + _len = len(str(train_kimg + transition_kimg)) + _num_it = (train_kimg + transition_kimg) // batch_size + for it in range(_num_it): + # determined current level: int for stabilizing and float for fading in + cur_level = R + float(max(cur_nimg-train_kimg, 0)) / transition_kimg + cur_resol = 2 ** int(np.ceil(cur_level+1)) + phase = 'stabilize' if int(cur_level) == cur_level else 'fade_in' + + # get a batch noise and real images + z = self.noise(batch_size) + x = self.data(batch_size, cur_resol, cur_level) + + # preprocess + self.preprocess(z, x) + + # update D + self.optim_D.zero_grad() + self.forward_D(cur_level, detach=True) # TODO: feed gdrop_strength + self.backward_D() + + # update G + self.optim_G.zero_grad() + self.forward_G(cur_level) + self.backward_G() + + # report + self.report(it, _num_it, phase, cur_resol) + + cur_nimg += batch_size + + # sampling + samples = [] + if (it % self.opts['sample_freq'] == 0) or it == _num_it-1: + samples = self.sample() + imsave(os.path.join(self.opts['sample_dir'], + '%dx%d-%s-%s.png' % (cur_resol, cur_resol, phase, str(it).zfill(6))), samples) + + # ===tensorboard visualization=== + if (it % self.opts['sample_freq'] == 0) or it == _num_it - 1: + self.tensorboard(it, _num_it, phase, cur_resol, samples) + + # save model + if (it % self.opts['save_freq'] == 0 and it > 0) or it == _num_it-1: + self.save(os.path.join(self.opts['ckpt_dir'], '%dx%d-%s-%s' % (cur_resol, cur_resol, phase, str(it).zfill(6)))) + + def sample(self, file_name): + batch_size = self.z.size(0) + n_row = self.rows_map[batch_size] + n_col = int(np.ceil(batch_size / float(n_row))) + samples = [] + i = j = 0 + for row in range(n_row): + one_row = [] + # fake + for col in range(n_col): + one_row.append(self.fake[i].cpu().data.numpy()) + i += 1 + # real + for col in range(n_col): + one_row.append(self.real[j].cpu().data.numpy()) + j += 1 + samples += [np.concatenate(one_row, axis=2)] + samples = np.concatenate(samples, axis=1).transpose([1, 2, 0]) + + half = samples.shape[1] // 2 + samples[:,:half,:] = samples[:,:half,:] - np.min(samples[:,:half,:]) + samples[:,:half,:] = samples[:,:half,:] / np.max(samples[:,:half,:]) + samples[:,half:,:] = samples[:,half:,:] - np.min(samples[:,half:,:]) + samples[:,half:,:] = samples[:,half:,:] / np.max(samples[:,half:,:]) + return samples + + def save(self, file_name): + g_file = file_name + '-G.pth' + d_file = file_name + '-D.pth' + torch.save(self.G.state_dict(), g_file) + torch.save(self.D.state_dict(), d_file) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--gpu', default='', type=str, help='gpu(s) to use.') + parser.add_argument('--train_kimg', default=600, type=float, help='# * 1000 real samples for each stabilizing training phase.') + parser.add_argument('--transition_kimg', default=600, type=float, help='# * 1000 real samples for each fading in phase.') + parser.add_argument('--g_lr_max', default=1e-3, type=float, help='Generator learning rate') + parser.add_argument('--d_lr_max', default=1e-3, type=float, help='Discriminator learning rate') + parser.add_argument('--beta1', default=0, type=float, help='beta1 for adam') + parser.add_argument('--beta2', default=0.99, type=float, help='beta2 for adam') + parser.add_argument('--gan', default='lsgan', type=str, help='model: lsgan/wgan_gp/gan, currently only support lsgan or gan with no_noise option.') + parser.add_argument('--first_resol', default=4, type=int, help='first resolution') + parser.add_argument('--target_resol', default=256, type=int, help='target resolution') + parser.add_argument('--drift', default=1e-3, type=float, help='drift, only available for wgan_gp.') + parser.add_argument('--sample_freq', default=500, type=int, help='sampling frequency.') + parser.add_argument('--save_freq', default=5000, type=int, help='save model frequency.') + parser.add_argument('--exp_dir', default='./exp', type=str, help='experiment dir.') + parser.add_argument('--no_noise', action='store_true', help='do not add noise to real data.') + + # TODO: support conditional inputs + + args = parser.parse_args() + opts = {k:v for k,v in args._get_kwargs()} + + latent_size = 512 + sigmoid_at_end = args.gan in ['lsgan', 'gan'] + + G = Generator(num_channels=3, latent_size=latent_size, resolution=args.target_resol, fmap_max=latent_size, fmap_base=8192, tanh_at_end=False) + D = Discriminator(num_channels=3, resolution=args.target_resol, fmap_max=latent_size, fmap_base=8192, sigmoid_at_end=sigmoid_at_end) + print(G) + print(D) + data = CelebA() + noise = RandomNoiseGenerator(latent_size, 'gaussian') + pggan = PGGAN(G, D, data, noise, opts) + pggan.train() diff --git a/utils/data.py b/utils/data.py old mode 100644 new mode 100755 index a098056..0837d6f --- a/utils/data.py +++ b/utils/data.py @@ -1,98 +1,100 @@ -# -*- coding: utf-8 -*- -import os, scipy.misc -from glob import glob -import numpy as np -import h5py - - -prefix = './datasets/' - -def get_img(img_path, is_crop=True, crop_h=256, resize_h=64, normalize=False): - img = scipy.misc.imread(img_path, mode='RGB').astype(np.float) - resize_w = resize_h - if is_crop: - crop_w = crop_h - h, w = img.shape[:2] - j = int(round((h - crop_h)/2.)) - i = int(round((w - crop_w)/2.)) - cropped_image = scipy.misc.imresize(img[j:j+crop_h, i:i+crop_w],[resize_h, resize_w]) - else: - cropped_image = scipy.misc.imresize(img,[resize_h, resize_w]) - if normalize: - cropped_image = cropped_image/127.5 - 1.0 - return np.transpose(cropped_image, [2, 0, 1]) - - -# class CelebA(): -# def __init__(self): -# datapath = os.path.join(prefix, 'celeba/aligned') -# self.channel = 3 -# self.data = glob(os.path.join(datapath, '*.jpg')) - -# def __call__(self, batch_size, size): -# batch_number = len(self.data)/batch_size -# path_list = [self.data[i] for i in np.random.randint(len(self.data), size=batch_size)] -# file_list = [p.split('/')[-1] for p in path_list] -# batch = [get_img(img_path, True, 178, size, True) for img_path in path_list] -# batch_imgs = np.array(batch).astype(np.float32) -# return batch_imgs - -# def save_imgs(self, samples, file_name): -# N_samples, channel, height, width = samples.shape -# N_row = N_col = int(np.ceil(N_samples**0.5)) -# combined_imgs = np.ones((channel, N_row*height, N_col*width)) -# for i in range(N_row): -# for j in range(N_col): -# if i*N_col+j < samples.shape[0]: -# combined_imgs[:,i*height:(i+1)*height, j*width:(j+1)*width] = samples[i*N_col+j] -# combined_imgs = np.transpose(combined_imgs, [1, 2, 0]) -# scipy.misc.imsave(file_name+'.png', combined_imgs) - - -class CelebA(): - def __init__(self): - datapath = 'celeba-hq-1024x1024.h5' - resolution = ['data2x2', 'data4x4', 'data8x8', 'data16x16', 'data32x32', 'data64x64', \ - 'data128x128', 'data256x256', 'data512x512', 'data1024x1024'] - self._base_key = 'data' - self.dataset = h5py.File(os.path.join(prefix, datapath), 'r') - self._len = {k:len(self.dataset[k]) for k in resolution} - assert all([resol in self.dataset.keys() for resol in resolution]) - - def __call__(self, batch_size, size, level=None): - key = self._base_key + '{}x{}'.format(size, size) - idx = np.random.randint(self._len[key], size=batch_size) - batch_x = np.array([self.dataset[key][i]/127.5-1.0 for i in idx], dtype=np.float32) - if level is not None: - if level != int(level): - min_lw, max_lw = int(level+1)-level, level-int(level) - lr_key = self._base_key + '{}x{}'.format(size//2, size//2) - low_resol_batch_x = np.array([self.dataset[lr_key][i]/127.5-1.0 for i in idx], dtype=np.float32).repeat(2, axis=2).repeat(2, axis=3) - batch_x = batch_x * max_lw + low_resol_batch_x * min_lw - return batch_x - - def save_imgs(self, samples, file_name): - N_samples, channel, height, width = samples.shape - N_row = N_col = int(np.ceil(N_samples**0.5)) - combined_imgs = np.ones((channel, N_row*height, N_col*width)) - for i in range(N_row): - for j in range(N_col): - if i*N_col+j < samples.shape[0]: - combined_imgs[:,i*height:(i+1)*height, j*width:(j+1)*width] = samples[i*N_col+j] - combined_imgs = np.transpose(combined_imgs, [1, 2, 0]) - scipy.misc.imsave(file_name+'.png', combined_imgs) - - -class RandomNoiseGenerator(): - def __init__(self, size, noise_type='gaussian'): - self.size = size - self.noise_type = noise_type.lower() - assert self.noise_type in ['gaussian', 'uniform'] - self.generator_map = {'gaussian': np.random.randn, 'uniform': np.random.uniform} - if self.noise_type == 'gaussian': - self.generator = lambda s: np.random.randn(*s) - elif self.noise_type == 'uniform': - self.generator = lambda s: np.random.uniform(-1, 1, size=s) - - def __call__(self, batch_size): - return self.generator([batch_size, self.size]).astype(np.float32) +# -*- coding: utf-8 -*- +import os, scipy.misc +from glob import glob +import numpy as np +import h5py + + +#prefix = 'C:\\Users\\yuan\\Downloads' +# prefix = '/Users/yuan/Downloads/' +prefix = './datasets/' + +def get_img(img_path, is_crop=True, crop_h=256, resize_h=64, normalize=False): + img = scipy.misc.imread(img_path, mode='RGB').astype(np.float) + resize_w = resize_h + if is_crop: + crop_w = crop_h + h, w = img.shape[:2] + j = int(round((h - crop_h)/2.)) + i = int(round((w - crop_w)/2.)) + cropped_image = scipy.misc.imresize(img[j:j+crop_h, i:i+crop_w],[resize_h, resize_w]) + else: + cropped_image = scipy.misc.imresize(img,[resize_h, resize_w]) + if normalize: + cropped_image = cropped_image/127.5 - 1.0 + return np.transpose(cropped_image, [2, 0, 1]) + + +# class CelebA(): +# def __init__(self): +# datapath = os.path.join(prefix, 'celeba/aligned') +# self.channel = 3 +# self.data = glob(os.path.join(datapath, '*.jpg')) + +# def __call__(self, batch_size, size): +# batch_number = len(self.data)/batch_size +# path_list = [self.data[i] for i in np.random.randint(len(self.data), size=batch_size)] +# file_list = [p.split('/')[-1] for p in path_list] +# batch = [get_img(img_path, True, 178, size, True) for img_path in path_list] +# batch_imgs = np.array(batch).astype(np.float32) +# return batch_imgs + +# def save_imgs(self, samples, file_name): +# N_samples, channel, height, width = samples.shape +# N_row = N_col = int(np.ceil(N_samples**0.5)) +# combined_imgs = np.ones((channel, N_row*height, N_col*width)) +# for i in range(N_row): +# for j in range(N_col): +# if i*N_col+j < samples.shape[0]: +# combined_imgs[:,i*height:(i+1)*height, j*width:(j+1)*width] = samples[i*N_col+j] +# combined_imgs = np.transpose(combined_imgs, [1, 2, 0]) +# scipy.misc.imsave(file_name+'.png', combined_imgs) + + +class CelebA(): + def __init__(self): + datapath = 'celeba-hq-1024x1024.h5' + resolution = ['data2x2', 'data4x4', 'data8x8', 'data16x16', 'data32x32', 'data64x64', \ + 'data128x128', 'data256x256', 'data512x512', 'data1024x1024'] + self._base_key = 'data' + self.dataset = h5py.File(os.path.join(prefix, datapath), 'r') + self._len = {k:len(self.dataset[k]) for k in resolution} + assert all([resol in self.dataset.keys() for resol in resolution]) + + def __call__(self, batch_size, size, level=None): + key = self._base_key + '{}x{}'.format(size, size) + idx = np.random.randint(self._len[key], size=batch_size) + batch_x = np.array([self.dataset[key][i]/127.5-1.0 for i in idx], dtype=np.float32) + if level is not None: + if level != int(level): + min_lw, max_lw = int(level+1)-level, level-int(level) + lr_key = self._base_key + '{}x{}'.format(size//2, size//2) + low_resol_batch_x = np.array([self.dataset[lr_key][i]/127.5-1.0 for i in idx], dtype=np.float32).repeat(2, axis=2).repeat(2, axis=3) + batch_x = batch_x * max_lw + low_resol_batch_x * min_lw + return batch_x + + def save_imgs(self, samples, file_name): + N_samples, channel, height, width = samples.shape + N_row = N_col = int(np.ceil(N_samples**0.5)) + combined_imgs = np.ones((channel, N_row*height, N_col*width)) + for i in range(N_row): + for j in range(N_col): + if i*N_col+j < samples.shape[0]: + combined_imgs[:,i*height:(i+1)*height, j*width:(j+1)*width] = samples[i*N_col+j] + combined_imgs = np.transpose(combined_imgs, [1, 2, 0]) + scipy.misc.imsave(file_name+'.png', combined_imgs) + + +class RandomNoiseGenerator(): + def __init__(self, size, noise_type='gaussian'): + self.size = size + self.noise_type = noise_type.lower() + assert self.noise_type in ['gaussian', 'uniform'] + self.generator_map = {'gaussian': np.random.randn, 'uniform': np.random.uniform} + if self.noise_type == 'gaussian': + self.generator = lambda s: np.random.randn(*s) + elif self.noise_type == 'uniform': + self.generator = lambda s: np.random.uniform(-1, 1, size=s) + + def __call__(self, batch_size): + return self.generator([batch_size, self.size]).astype(np.float32) diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..694c985 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +""" +Reference : https://github.com/SherlockLiao/pytorch-beginner/tree/master/04-Convolutional%20Neural%20Network +""" +import tensorflow as tf +import numpy as np +import scipy.misc +try: + from StringIO import StringIO # Python 2.7 +except ImportError: + from io import BytesIO # Python 3.x + + +class Logger(object): + + def __init__(self, log_dir): + """Create a summary writer logging to log_dir.""" + self.writer = tf.summary.FileWriter(log_dir) + + def scalar_summary(self, tag, value, step): + """Log a scalar variable.""" + summary = tf.Summary(value=[tf.Summary.Value(tag=tag, + simple_value=value)]) + self.writer.add_summary(summary, step) + + def image_summary(self, tag, images, step): + """Log a list of images.""" + + img_summaries = [] + for i, img in enumerate(images): + # Write the image to a string + try: + s = StringIO() + except: + s = BytesIO() + scipy.misc.toimage(img).save(s, format="png") + + # Create an Image object + img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), + height=img.shape[0], + width=img.shape[1]) + # Create a Summary value + img_summaries.append( + tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) + + # Create and write Summary + summary = tf.Summary(value=img_summaries) + self.writer.add_summary(summary, step) + + def histo_summary(self, tag, values, step, bins=1000): + """Log a histogram of the tensor of values.""" + + # Create a histogram using numpy + counts, bin_edges = np.histogram(values, bins=bins) + + # Fill the fields of the histogram proto + hist = tf.HistogramProto() + hist.min = float(np.min(values)) + hist.max = float(np.max(values)) + hist.num = int(np.prod(values.shape)) + hist.sum = float(np.sum(values)) + hist.sum_squares = float(np.sum(values**2)) + + # Drop the start of the first bin + bin_edges = bin_edges[1:] + + # Add bin edges and counts + for edge in bin_edges: + hist.bucket_limit.append(edge) + for c in counts: + hist.bucket.append(c) + + # Create and write Summary + summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) + self.writer.add_summary(summary, step) + self.writer.flush()