Partial Adversarial Training
mohitzsh committed Nov 26, 2017
1 parent 4c36202 commit aab59b4
Showing 3 changed files with 111 additions and 59 deletions.
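For context: the changes below appear to follow the standard adversarial-training setup for semantic segmentation, in which a segmentation network S is trained against a fully convolutional discriminator D, and the new `--lam_adv` flag weights the adversarial term. A hedged sketch of the objective implied by the new loss terms (notation is ours, not the repo's):

```latex
% Sketch only; \lambda_{adv} corresponds to the --lam_adv flag below.
\mathcal{L}_G = \mathcal{L}_{ce}\bigl(S(x),\,y\bigr)
             + \lambda_{adv}\,\mathcal{L}_{ce}\bigl(D(S(x)),\,\text{real}\bigr)
\qquad
\mathcal{L}_D = \mathcal{L}_{ce}\bigl(D(y),\,\text{real}\bigr)
             + \mathcal{L}_{ce}\bigl(D(S(x)),\,\text{fake}\bigr)
```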
20 changes: 9 additions & 11 deletions discriminators/discriminator.py
@@ -12,16 +12,15 @@ def __init__(self,in_channels,negative_slope = 0.2):
self._in_channels = in_channels
self._negative_slope = negative_slope

- self.conv1 = nn.Conv2d(in_channels=self._in_channels,out_channels=64,kernel_size=4,stride=2,padding=1)
- self.relu1 = nn.LeakyReLU(self._negative_slope)
- self.conv2 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4,stride=2,padding=1)
- self.relu2 = nn.LeakyReLU(self._negative_slope)
- self.conv3 = nn.Conv2d(in_channels=128,out_channels=256,kernel_size=4,stride=2,padding=1)
- self.relu3 = nn.LeakyReLU(self._negative_slope)
- self.conv4 = nn.Conv2d(in_channels=256,out_channels=512,kernel_size=4,stride=2,padding=1)
- self.relu4 = nn.LeakyReLU(self._negative_slope)
-
- self.conv5 = nn.Conv2d(in_channels=512,out_channels=1,kernel_size=4,stride=2,padding=1)
+ self.conv1 = nn.Conv2d(in_channels=self._in_channels,out_channels=64,kernel_size=4,stride=2,padding=2)
+ self.relu1 = nn.LeakyReLU(self._negative_slope,inplace=True)
+ self.conv2 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4,stride=2,padding=2)
+ self.relu2 = nn.LeakyReLU(self._negative_slope,inplace=True)
+ self.conv3 = nn.Conv2d(in_channels=128,out_channels=256,kernel_size=4,stride=2,padding=2)
+ self.relu3 = nn.LeakyReLU(self._negative_slope,inplace=True)
+ self.conv4 = nn.Conv2d(in_channels=256,out_channels=512,kernel_size=4,stride=2,padding=2)
+ self.relu4 = nn.LeakyReLU(self._negative_slope,inplace=True)
+ self.conv5 = nn.Conv2d(in_channels=512,out_channels=2,kernel_size=4,stride=2,padding=2)

def forward(self,x):
x= self.conv1(x) # -,-,161,161
@@ -33,7 +32,6 @@ def forward(self,x):
x= self.conv4(x) # -,-,21,21
x = self.relu4(x)
x = self.conv5(x) # -,-,11,11

# upsample
x = F.upsample_bilinear(x,scale_factor=2)
x = x[:,:,:-1,:-1] # -,-, 21,21
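The shape comments in forward() (161, 81, 41, 21, 11) are consistent with the new kernel_size=4, stride=2, padding=2 settings: each convolution maps H to floor((H + 2·2 − 4)/2) + 1 = floor(H/2) + 1. A standalone sanity-check sketch; the 321×321 crop and the 21 input channels (PASCAL VOC one-hot masks) are assumptions, since neither appears in this diff:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

# Each Conv2d(kernel_size=4, stride=2, padding=2) maps H to floor(H/2) + 1,
# so an assumed 321x321 input shrinks 321 -> 161 -> 81 -> 41 -> 21 -> 11.
convs = nn.Sequential(
    nn.Conv2d(21, 64, kernel_size=4, stride=2, padding=2),
    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=2),
    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=2),
    nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=2),
    nn.Conv2d(512, 2, kernel_size=4, stride=2, padding=2),
)
x = Variable(torch.randn(1, 21, 321, 321))
for conv in convs:
    x = conv(x)
    print(x.size())  # ends at (1, 2, 11, 11)
```

The trailing upsample-then-crop in forward() (11 → 22 → 21) brings the confidence map back to the 21×21 grid noted in the comments.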
135 changes: 91 additions & 44 deletions train_base.py
@@ -34,6 +34,9 @@ def main():
and cls (GT Segmentation) folder")
parser.add_argument("--mode",help="base (baseline),adv (adversarial), semi \
(semi-supervised)",choices=('base','adv','semi'),default='base')
parser.add_argument("--lam_adv",help="Weight for Adversarial loss for Segmentation Network training",\
default=0.01)
parser.add_argument("--nogpu",help="Train only on cpus. Helpful for debugging",action='store_true')
parser.add_argument("--max_epoch",help="Maximum iterations.",default=20,\
type=int)
parser.add_argument("--start_epoch",help="Resume training from this epoch",\
@@ -86,8 +89,8 @@ def main():
print("Discriminator Loaded")

# Assumptions made. Paper doesn't clarify the details
- optimizer_D = optim.Adam(filter(lambda p: p.requires_grad, \
- discriminator.parameters()),lr=0.0001,weight_decay=0.0001)
+ optimizer_D = optim.SGD(filter(lambda p: p.requires_grad, \
+ discriminator.parameters()),lr=0.0001,weight_decay=0.0001,momentum=0.5,nesterov=True)
print("Discriminator Optimizer Loaded")

# Load the snapshot if available
@@ -109,63 +112,107 @@ def main():
new_state.update(saved_net)
generator.load_state_dict(new_state)

- generator = nn.DataParallel(generator).cuda()
- print("Generator Setup for Parallel Training")
+ if not args.nogpu:
+ generator = nn.DataParallel(generator).cuda()
+ # generator = generator.cuda(0)
+ print("Generator Setup for Parallel Training")
+ # print("Generator Loaded on device 0")
+ else:
+ print("No Parallel Training for CPU")


if args.mode == 'adv':
- discriminator = nn.DataParallel(discriminator).cuda()
- print("Discriminator Setup for Parallel Training")
+ if args.nogpu:
+ print("No Parallel Training for CPU")
+ else:
+ discriminator = nn.DataParallel(discriminator).cuda()
+ # discriminator = discriminator.cuda(1)
+ print("Discriminator Setup for parallel training")
+ # print("Discriminator Loaded on device 1")

best_miou = -1
print('Training Going to Start')
for epoch in range(args.start_epoch,args.max_epoch+1):
generator.train()
for batch_id, (img,mask,ohmask) in enumerate(trainloader):
- img,mask,ohmask = Variable(img.cuda()),Variable(mask.cuda(),requires_grad=False),\
- Variable(ohmask.cuda(),requires_grad=False)

- # Generator Step
+ if args.nogpu:
+ img,mask,ohmask = Variable(img),Variable(mask,requires_grad=False),\
+ Variable(ohmask,requires_grad=False)
+ else:
+ img,mask,ohmask = Variable(img.cuda()),Variable(mask.cuda(),requires_grad=False),\
+ Variable(ohmask.cuda(),requires_grad=False)

+ # Generate Prediction Map with the Segmentation Network
out_img_map = generator(img)
out_img_map = nn.LogSoftmax()(out_img_map)

- import pdb; pdb.set_trace()
- #Discriminator Step
- conf_map_true = nn.LogSoftmax()(discriminator(ohmask))
- conf_map_true = torch.cat((1- conf_map_true,conf_map_true),dim = 1)

- conf_map_fake = nn.LogSoftmax()(discriminator(out_img_map))
- conf_map_fake = torch.cat((1- conf_map_fake,conf_map_fake),dim = 1)

- target_true = Variable(torch.ones(conf_map_true.size()).squeeze(1).cuda(),requires_grad=False)
- target_fake = Variable(torch.zeros(conf_map_fake.size()).squeeze(1).cuda(),requires_grad=False)

- # Segmentation loss
- L_seg = nn.NLLLoss2d()(out_img_map,mask)
+ #####################
+ # Baseline Training #
+ #####################
+ if args.mode == 'base':
+ print("Baseline Training")
+ L_seg = nn.NLLLoss2d()(out_img_map,mask)

- if args.mode == 'adv':
- L_seg += nn.NLLLoss2d()(conf_map_fake,target_true)
i = len(trainloader)*(epoch-1) + batch_id
poly_lr_scheduler(optimizer_G, 0.00025, i)

optimizer_G.zero_grad()
L_seg.backward(retain_variables=True)
optimizer_G.step()
print("[{}][{}]Loss: {}".format(epoch,i,L_seg.data[0]))

+ ########################
+ # Adversarial Training #
+ ########################
if args.mode == 'adv':
- # Discriminator Loss
- L_d = nn.NLLLoss2d()(conf_map_true,target_true) + nn.NLLLoss2d()(conf_map_fake,target_fake)


- i = len(trainloader)*(epoch-1) + batch_id
- poly_lr_scheduler(optimizer_G, 0.00025, i)

- # Generator Step
- optimizer_G.zero_grad()
- L_seg.backward()
- optimizer_G.step()

- # Discriminator step
- optimizer_D.zero_grad()
- L_d.backward()
- optimizer_D.step()

print("Epoch {} Finished!".format(epoch))

+ N = out_img_map.size()[0]
+ H = out_img_map.size()[2]
+ W = out_img_map.size()[3]

+ # Generate the Real and Fake Labels
+ target_fake = Variable(torch.zeros((N,H,W)).long(),requires_grad=False)
+ target_real = Variable(torch.ones((N,H,W)).long(),requires_grad=False)
+ if not args.nogpu:
+ target_fake = target_fake.cuda()
+ target_real = target_real.cuda()

+ ##########################
+ # Discriminator Training #
+ ##########################

+ # Train on Real
+ conf_map_real = nn.LogSoftmax()(discriminator(ohmask.float()))

+ optimizer_D.zero_grad()

+ LD_real = nn.NLLLoss2d()(conf_map_real,target_real)
+ LD_real.backward()

+ # Train on Fake
+ conf_map_fake = nn.LogSoftmax()(discriminator(Variable(out_img_map.data)))
+ LD_fake = nn.NLLLoss2d()(conf_map_fake,target_fake)
+ LD_fake.backward()

+ # Update Discriminator weights
+ i = len(trainloader)*(epoch-1) + batch_id
+ poly_lr_scheduler(optimizer_D, 0.00025, i)

+ optimizer_D.step()

+ ######################
+ # Generator Training #
+ ######################
+ conf_map_fake = nn.LogSoftmax()(discriminator(out_img_map))
+ LG_ce = nn.NLLLoss2d()(out_img_map,mask)
+ LG_adv = args.lam_adv * nn.NLLLoss2d()(conf_map_fake,target_real)

+ LG_seg = LG_ce.data[0] + LG_adv.data[0]
+ optimizer_G.zero_grad()
+ LG_ce.backward(retain_variables=True)
+ LG_adv.backward()
+ poly_lr_scheduler(optimizer_G, 0.00025, i)
+ optimizer_G.step()
+ print("[{}][{}] LD: {} LG: {}".format(epoch,i,(LD_real + LD_fake).data[0],LG_seg))
snapshot = {
'epoch': epoch,
'state_dict': generator.state_dict(),
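Both training paths step the learning rate through poly_lr_scheduler(optimizer, 0.00025, i), which comes from the repo's utils and is outside this diff. A minimal stand-in consistent with those call sites, assuming the DeepLab-style "poly" policy (the max_iter and power values here are assumptions):

```python
def poly_lr_scheduler(optimizer, init_lr, curr_iter, max_iter=20000, power=0.9):
    # Hypothetical stand-in for the repo's utils poly_lr_scheduler:
    # decay the LR as init_lr * (1 - curr_iter / max_iter) ** power.
    lr = init_lr * (1 - curr_iter / float(max_iter)) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
```

Also worth noting in the new adversarial branch: the discriminator's fake pass runs on Variable(out_img_map.data), a detached copy, so LD_fake.backward() cannot reach the generator; the later discriminator(out_img_map) call is attached, which is what lets LG_adv update the generator.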
15 changes: 11 additions & 4 deletions utils/validate.py
@@ -11,16 +11,23 @@
from utils.transforms import RandomSizedCrop, IgnoreLabelClass, ToTensorLabel, NormalizeOwn, ZeroPadding
from torchvision.transforms import ToTensor,Compose

- def val(model,valoader,nclass=21):
+ def val(model,valoader,nclass=21,nogpu=False):
model.eval()
gts, preds = [], []
- for img_id, (img,gt_mask) in enumerate(valoader):
+ for img_id, (img,gt_mask,_) in enumerate(valoader):
gt_mask = gt_mask.numpy()[0]
- img = Variable(img.cuda(),volatile=True)
+ if nogpu:
+ img = Variable(img,volatile=True)
+ else:
+ img = Variable(img.cuda(),volatile=True)
out_pred_map = model(img)

# Get hard prediction
- soft_pred = out_pred_map.data.cpu().numpy()[0]
+ if nogpu:
+ soft_pred = out_pred_map.data.numpy()[0]
+ else:
+ soft_pred = out_pred_map.data.cpu().numpy()[0]

soft_pred = soft_pred[:,:gt_mask.shape[0],:gt_mask.shape[1]]
hard_pred = np.argmax(soft_pred,axis=0).astype(np.uint8)
for gt_, pred_ in zip(gt_mask, hard_pred):
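The lines shown only accumulate gts and preds; the mean-IoU reduction happens past the end of this hunk. For reference, a hedged sketch of the standard confusion-matrix reduction such a val() typically performs (the function name and the ignore-label handling are assumptions):

```python
import numpy as np

def mean_iou(gts, preds, nclass=21):
    # Accumulate an nclass x nclass confusion matrix over all (gt, pred)
    # pairs, then compute per-class IoU = TP / (TP + FP + FN) and average.
    conf = np.zeros((nclass, nclass), dtype=np.int64)
    for gt, pred in zip(gts, preds):
        gt = np.asarray(gt).ravel().astype(np.int64)
        pred = np.asarray(pred).ravel().astype(np.int64)
        valid = (gt >= 0) & (gt < nclass)  # drop ignore labels (e.g. 255)
        conf += np.bincount(nclass * gt[valid] + pred[valid],
                            minlength=nclass ** 2).reshape(nclass, nclass)
    tp = np.diag(conf)
    iou = tp / np.maximum(conf.sum(axis=0) + conf.sum(axis=1) - tp, 1)
    return iou.mean()
```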
