diff --git a/discriminators/discriminator.py b/discriminators/discriminator.py
index 1f26b66..a13a833 100644
--- a/discriminators/discriminator.py
+++ b/discriminators/discriminator.py
@@ -12,16 +12,15 @@ def __init__(self,in_channels,negative_slope = 0.2):
         self._in_channels = in_channels
         self._negative_slope = negative_slope

-        self.conv1 = nn.Conv2d(in_channels=self._in_channels,out_channels=64,kernel_size=4,stride=2,padding=1)
-        self.relu1 = nn.LeakyReLU(self._negative_slope)
-        self.conv2 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4,stride=2,padding=1)
-        self.relu2 = nn.LeakyReLU(self._negative_slope)
-        self.conv3 = nn.Conv2d(in_channels=128,out_channels=256,kernel_size=4,stride=2,padding=1)
-        self.relu3 = nn.LeakyReLU(self._negative_slope)
-        self.conv4 = nn.Conv2d(in_channels=256,out_channels=512,kernel_size=4,stride=2,padding=1)
-        self.relu4 = nn.LeakyReLU(self._negative_slope)
-
-        self.conv5 = nn.Conv2d(in_channels=512,out_channels=1,kernel_size=4,stride=2,padding=1)
+        self.conv1 = nn.Conv2d(in_channels=self._in_channels,out_channels=64,kernel_size=4,stride=2,padding=2)
+        self.relu1 = nn.LeakyReLU(self._negative_slope,inplace=True)
+        self.conv2 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4,stride=2,padding=2)
+        self.relu2 = nn.LeakyReLU(self._negative_slope,inplace=True)
+        self.conv3 = nn.Conv2d(in_channels=128,out_channels=256,kernel_size=4,stride=2,padding=2)
+        self.relu3 = nn.LeakyReLU(self._negative_slope,inplace=True)
+        self.conv4 = nn.Conv2d(in_channels=256,out_channels=512,kernel_size=4,stride=2,padding=2)
+        self.relu4 = nn.LeakyReLU(self._negative_slope,inplace=True)
+        self.conv5 = nn.Conv2d(in_channels=512,out_channels=2,kernel_size=4,stride=2,padding=2)

     def forward(self,x):
         x= self.conv1(x) # -,-,161,161
@@ -33,7 +32,6 @@ def forward(self,x):
         x= self.conv4(x) # -,-,21,21
         x = self.relu4(x)
         x = self.conv5(x) # -,-,11,11
-        # upsample
         x = F.upsample_bilinear(x,scale_factor=2)
         x = x[:,:,:-1,:-1] # -,-, 21,21
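
The padding increase (1 to 2) and the 2-channel final convolution make the discriminator a fully convolutional per-pixel real/fake classifier whose LogSoftmax output feeds NLLLoss2d in train_base.py. The spatial sizes in the forward() comments follow from standard convolution arithmetic; a minimal standalone check, assuming 321x321 input crops (the input size is not stated in the diff itself):

# Size check for kernel_size=4, stride=2, padding=2, assuming 321x321 inputs.
def conv_out(size, kernel=4, stride=2, padding=2):
    return (size + 2 * padding - kernel) // stride + 1

size = 321
for layer in range(1, 6):
    size = conv_out(size)
    print("conv{}: {}x{}".format(layer, size, size))  # 161, 81, 41, 21, 11
# A 2x bilinear upsample gives 22x22; cropping with x[:, :, :-1, :-1] returns to 21x21.
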
diff --git a/train_base.py b/train_base.py
index 5d1f92a..bf3ba0b 100644
--- a/train_base.py
+++ b/train_base.py
@@ -34,6 +34,9 @@ def main():
                         and cls (GT Segmentation) folder")
     parser.add_argument("--mode",help="base (baseline),adv (adversarial), semi \
                         (semi-supervised)",choices=('base','adv','semi'),default='base')
+    parser.add_argument("--lam_adv",help="Weight for Adversarial loss for Segmentation Network training",\
+                        default=0.01,type=float)
+    parser.add_argument("--nogpu",help="Train only on CPUs. Helpful for debugging",action='store_true')
     parser.add_argument("--max_epoch",help="Maximum iterations.",default=20,\
                         type=int)
     parser.add_argument("--start_epoch",help="Resume training from this epoch",\
@@ -86,8 +89,8 @@ def main():
         print("Discriminator Loaded")
         # Assumptions made. Paper doesn't clarify the details
-        optimizer_D = optim.Adam(filter(lambda p: p.requires_grad, \
-            discriminator.parameters()),lr=0.0001,weight_decay=0.0001)
+        optimizer_D = optim.SGD(filter(lambda p: p.requires_grad, \
+            discriminator.parameters()),lr=0.0001,weight_decay=0.0001,momentum=0.5,nesterov=True)
         print("Discriminator Optimizer Loaded")

     # Load the snapshot if available
@@ -109,63 +112,107 @@ def main():
         new_state.update(saved_net)
         generator.load_state_dict(new_state)

-    generator = nn.DataParallel(generator).cuda()
-    print("Generator Setup for Parallel Training")
+    if not args.nogpu:
+        generator = nn.DataParallel(generator).cuda()
+        # generator = generator.cuda(0)
+        print("Generator Setup for Parallel Training")
+        # print("Generator Loaded on device 0")
+    else:
+        print("No Parallel Training for CPU")
+
     if args.mode == 'adv':
-        discriminator = nn.DataParallel(discriminator).cuda()
-        print("Discriminator Setup for Parallel Training")
+        if args.nogpu:
+            print("No Parallel Training for CPU")
+        else:
+            discriminator = nn.DataParallel(discriminator).cuda()
+            # discriminator = discriminator.cuda(1)
+            print("Discriminator Setup for Parallel Training")
+            # print("Discriminator Loaded on device 1")

     best_miou = -1
-    print('Training Going to Start')
     for epoch in range(args.start_epoch,args.max_epoch+1):
         generator.train()
         for batch_id, (img,mask,ohmask) in enumerate(trainloader):
-            img,mask,ohmask = Variable(img.cuda()),Variable(mask.cuda(),requires_grad=False),\
-                Variable(ohmask.cuda(),requires_grad=False)
-
-            # Generator Step
+            if args.nogpu:
+                img,mask,ohmask = Variable(img),Variable(mask,requires_grad=False),\
+                    Variable(ohmask,requires_grad=False)
+            else:
+                img,mask,ohmask = Variable(img.cuda()),Variable(mask.cuda(),requires_grad=False),\
+                    Variable(ohmask.cuda(),requires_grad=False)
+
+            # Generate Prediction Map with the Segmentation Network
            out_img_map = generator(img)
            out_img_map = nn.LogSoftmax()(out_img_map)

-            import pdb; pdb.set_trace()
-            #Discriminator Step
-            conf_map_true = nn.LogSoftmax()(discriminator(ohmask))
-            conf_map_true = torch.cat((1- conf_map_true,conf_map_true),dim = 1)
-
-            conf_map_fake = nn.LogSoftmax()(discriminator(out_img_map))
-            conf_map_fake = torch.cat((1- conf_map_fake,conf_map_fake),dim = 1)
-
-            target_true = Variable(torch.ones(conf_map_true.size()).squeeze(1).cuda(),requires_grad=False)
-            target_fake = Variable(torch.zeros(conf_map_fake.size()).squeeze(1).cuda(),requires_grad=False)
-
-            # Segmentation loss
-            L_seg = nn.NLLLoss2d()(out_img_map,mask)
+            #####################
+            # Baseline Training #
+            #####################
+            if args.mode == 'base':
+                print("Baseline Training")
+                L_seg = nn.NLLLoss2d()(out_img_map,mask)

-            if args.mode == 'adv':
-                L_seg += nn.NLLLoss2d()(conf_map_fake,target_true)
+                i = len(trainloader)*(epoch-1) + batch_id
+                poly_lr_scheduler(optimizer_G, 0.00025, i)
+                optimizer_G.zero_grad()
+                L_seg.backward(retain_variables=True)
+                optimizer_G.step()
+                print("[{}][{}]Loss: {}".format(epoch,i,L_seg.data[0]))

+            ########################
+            # Adversarial Training #
+            ########################
             if args.mode == 'adv':
-                # Discriminator Loss
-                L_d = nn.NLLLoss2d()(conf_map_true,target_true) + nn.NLLLoss2d()(conf_map_fake,target_fake)
-
-
-            i = len(trainloader)*(epoch-1) + batch_id
-            poly_lr_scheduler(optimizer_G, 0.00025, i)
-
-            # Generator Step
-            optimizer_G.zero_grad()
-            L_seg.backward()
-            optimizer_G.step()
-
-            # Discriminator step
-            optimizer_D.zero_grad()
-            L_d.backward()
-            optimizer_D.step()
-
-        print("Epoch {} Finished!".format(epoch))
+                N = out_img_map.size()[0]
+                H = out_img_map.size()[2]
+                W = out_img_map.size()[3]
+
+                # Generate the Real and Fake Labels
+                target_fake = Variable(torch.zeros((N,H,W)).long(),requires_grad=False)
+                target_real = Variable(torch.ones((N,H,W)).long(),requires_grad=False)
+                if not args.nogpu:
+                    target_fake = target_fake.cuda()
+                    target_real = target_real.cuda()
+
+                ##########################
+                # Discriminator Training #
+                ##########################
+
+                # Train on Real
+                conf_map_real = nn.LogSoftmax()(discriminator(ohmask.float()))
+
+                optimizer_D.zero_grad()
+
+                LD_real = nn.NLLLoss2d()(conf_map_real,target_real)
+                LD_real.backward()
+
+                # Train on Fake
+                conf_map_fake = nn.LogSoftmax()(discriminator(Variable(out_img_map.data)))
+                LD_fake = nn.NLLLoss2d()(conf_map_fake,target_fake)
+                LD_fake.backward()
+
+                # Update Discriminator weights
+                i = len(trainloader)*(epoch-1) + batch_id
+                poly_lr_scheduler(optimizer_D, 0.00025, i)
+
+                optimizer_D.step()
+
+                ######################
+                # Generator Training #
+                ######################
+                conf_map_fake = nn.LogSoftmax()(discriminator(out_img_map))
+                LG_ce = nn.NLLLoss2d()(out_img_map,mask)
+                LG_adv = args.lam_adv * nn.NLLLoss2d()(conf_map_fake,target_real)
+
+                LG_seg = LG_ce.data[0] + LG_adv.data[0]
+                optimizer_G.zero_grad()
+                LG_ce.backward(retain_variables=True)
+                LG_adv.backward()
+                poly_lr_scheduler(optimizer_G, 0.00025, i)
+                optimizer_G.step()
+                print("[{}][{}] LD: {} LG: {}".format(epoch,i,(LD_real + LD_fake).data[0],LG_seg))

         snapshot = {
             'epoch': epoch,
             'state_dict': generator.state_dict(),
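
In the adversarial branch the discriminator is updated first (ground-truth one-hot masks labelled 1, detached segmentation predictions labelled 0), then the segmentation network is updated with pixel-wise cross-entropy plus lam_adv times the loss for making the discriminator predict 1 on its outputs. A compact sketch of the same alternation in current PyTorch idiom, using detach() and CrossEntropyLoss instead of Variable/.data/NLLLoss2d; segmenter, disc, opt_G and opt_D are placeholder names rather than the script's identifiers:

import torch
import torch.nn as nn
import torch.nn.functional as F

def adversarial_step(segmenter, disc, opt_G, opt_D, img, mask, ohmask, lam_adv=0.01):
    # Segmentation log-probabilities, N x C x H x W.
    pred = F.log_softmax(segmenter(img), dim=1)

    # Per-pixel real/fake targets sized to the discriminator's output map.
    conf_real = disc(ohmask.float())                      # N x 2 x h x w logits
    real = torch.ones(conf_real.size(0), conf_real.size(2), conf_real.size(3),
                      dtype=torch.long, device=conf_real.device)
    fake = torch.zeros_like(real)
    ce = nn.CrossEntropyLoss()                            # LogSoftmax + NLLLoss per pixel

    # Discriminator: ground-truth one-hot masks -> 1, detached predictions -> 0.
    opt_D.zero_grad()
    loss_d = ce(conf_real, real) + ce(disc(pred.detach()), fake)
    loss_d.backward()
    opt_D.step()

    # Generator: pixel-wise cross-entropy plus the weighted "fool the discriminator" term.
    opt_G.zero_grad()
    loss_g = nn.NLLLoss()(pred, mask) + lam_adv * ce(disc(pred), real)
    loss_g.backward()
    opt_G.step()
    return loss_d.item(), loss_g.item()

Sizing the real/fake targets from the discriminator's own output keeps the sketch independent of the generator's output resolution.
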
diff --git a/utils/validate.py b/utils/validate.py
index 38dfea3..3cfc355 100644
--- a/utils/validate.py
+++ b/utils/validate.py
@@ -11,16 +11,23 @@
 from utils.transforms import RandomSizedCrop, IgnoreLabelClass, ToTensorLabel, NormalizeOwn, ZeroPadding
 from torchvision.transforms import ToTensor,Compose

-def val(model,valoader,nclass=21):
+def val(model,valoader,nclass=21,nogpu=False):
     model.eval()
     gts, preds = [], []
-    for img_id, (img,gt_mask) in enumerate(valoader):
+    for img_id, (img,gt_mask,_) in enumerate(valoader):
         gt_mask = gt_mask.numpy()[0]
-        img = Variable(img.cuda(),volatile=True)
+        if nogpu:
+            img = Variable(img,volatile=True)
+        else:
+            img = Variable(img.cuda(),volatile=True)
         out_pred_map = model(img)

         # Get hard prediction
-        soft_pred = out_pred_map.data.cpu().numpy()[0]
+        if nogpu:
+            soft_pred = out_pred_map.data.numpy()[0]
+        else:
+            soft_pred = out_pred_map.data.cpu().numpy()[0]
+
+        soft_pred = soft_pred[:,:gt_mask.shape[0],:gt_mask.shape[1]]
         hard_pred = np.argmax(soft_pred,axis=0).astype(np.uint8)
         for gt_, pred_ in zip(gt_mask, hard_pred):
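
Since validation images go through ZeroPadding, the network output can be larger than the ground-truth mask, so the new slicing line crops the score map to the label size before the per-pixel argmax. A tiny self-contained illustration with dummy NumPy arrays (the shapes here are made up):

import numpy as np

soft_pred = np.random.rand(21, 324, 324)        # C x H_pad x W_pad class scores
gt_mask = np.random.randint(0, 21, (321, 321))  # H x W ground-truth labels

# Crop to the label size, then take the per-pixel argmax.
soft_pred = soft_pred[:, :gt_mask.shape[0], :gt_mask.shape[1]]
hard_pred = np.argmax(soft_pred, axis=0).astype(np.uint8)
assert hard_pred.shape == gt_mask.shape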