Skip to content

Commit

Permalink
Make python3 compatibility changes in main and utils files.
Browse files Browse the repository at this point in the history
  • Loading branch information
Karan Desai committed Oct 22, 2018
1 parent 250455b commit d5d17f3
Show file tree
Hide file tree
Showing 11 changed files with 260 additions and 19 deletions.
18 changes: 8 additions & 10 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@
import numpy as np
import time
import os
from six.moves import cPickle
import pickle
import torch.backends.cudnn as cudnn
import yaml

import opts
import misc.eval_utils
import misc.utils as utils
import misc.AttModel as AttModel
from misc import utils, eval_utils, AttModel
import yaml

# from misc.rewards import get_self_critical_reward
Expand Down Expand Up @@ -282,8 +280,8 @@ def eval(opt):
info_path = os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')

# open old infos and check if models are compatible
with open(info_path) as f:
infos = cPickle.load(f)
with open(info_path, 'rb') as f:
infos = pickle.load(f)
saved_model_opt = infos['opt']

# opt.learning_rate = saved_model_opt.learning_rate
Expand All @@ -292,7 +290,7 @@ def eval(opt):

if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f:
histories = cPickle.load(f)
histories = pickle.load(f)

if opt.decode_noc:
model._reinit_word_weight(opt, dataset.ctoi, dataset.wtoi)
Expand Down Expand Up @@ -383,9 +381,9 @@ def eval(opt):
histories['lr_history'] = lr_history
histories['ss_prob_history'] = ss_prob_history
with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
cPickle.dump(infos, f)
pickle.dump(infos, f)
with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
cPickle.dump(histories, f)
pickle.dump(histories, f)

if best_flag:
checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
Expand All @@ -396,4 +394,4 @@ def eval(opt):

print("model saved to {} with best cider score {:.3f}".format(checkpoint_path, best_val_score))
with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
cPickle.dump(infos, f)
pickle.dump(infos, f)
Binary file added misc/__pycache__/__init__.cpython-36.pyc
Binary file not shown.
Binary file added misc/__pycache__/eval_utils.cpython-36.pyc
Binary file not shown.
6 changes: 3 additions & 3 deletions misc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def forward(self, txt_input, vis_input, target):
txt_mask = Variable(txt_mask)
txt_out = - torch.masked_select(txt_select, txt_mask.view(-1,1))

loss = (torch.sum(txt_out)+torch.sum(vis_out)) / (torch.sum(txt_mask.data) + torch.sum(vis_mask.data))
loss = (torch.sum(txt_out)+torch.sum(vis_out)).float() / (torch.sum(txt_mask.data) + torch.sum(vis_mask.data)).float()

return loss

Expand All @@ -249,7 +249,7 @@ def forward(self, input, target):
select = torch.gather(input.view(-1,2), 1, Variable(new_target))

out = - torch.masked_select(select, bn_mask)
loss = torch.sum(out) / torch.sum(bn_mask.data)
loss = torch.sum(out).float() / torch.sum(bn_mask.data).float()
else:
loss = Variable(input.data.new(1).zero_()).float()

Expand All @@ -272,7 +272,7 @@ def forward(self, input, target):

if torch.sum(attr_mask.data) > 0:
out = - torch.masked_select(select, attr_mask)
loss = torch.sum(out) / torch.sum(attr_mask.data)
loss = torch.sum(out).float() / torch.sum(attr_mask.data).float()
else:
loss = Variable(input.data.new(1).zero_()).float()

Expand Down
26 changes: 26 additions & 0 deletions pooling/make.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/usr/bin/env bash

# CUDA_PATH=/usr/local/cuda/

export CUDA_PATH=/usr/local/cuda/
#You may also want to ad the following
#export C_INCLUDE_PATH=/opt/cuda/include

export CXXFLAGS="-std=c++11"
export CFLAGS="-std=c99"

CUDA_ARCH="-gencode arch=compute_30,code=sm_30 \
-gencode arch=compute_35,code=sm_35 \
-gencode arch=compute_50,code=sm_50 \
-gencode arch=compute_52,code=sm_52 \
-gencode arch=compute_60,code=sm_60 \
-gencode arch=compute_61,code=sm_61 "

# compile roi_align
cd roi_align/src
echo "Compiling roi align kernels by nvcc..."
nvcc -c -o roi_align_kernel.cu.o roi_align_kernel.cu \
-D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -arch=$CUDA_ARCH
cd ../
python build.py
cd ../
5 changes: 4 additions & 1 deletion pooling/roi_align/_ext/roi_align/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
def _import_symbols(locals):
for symbol in dir(_lib):
fn = getattr(_lib, symbol)
locals[symbol] = _wrap_function(fn, _ffi)
if callable(fn):
locals[symbol] = _wrap_function(fn, _ffi)
else:
locals[symbol] = fn
__all__.append(symbol)

_import_symbols(locals())
Binary file added pooling/roi_align/_ext/roi_align/_roi_align.so
Binary file not shown.
6 changes: 5 additions & 1 deletion pooling/roi_align/functions/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ def forward(self, features, rois):
self.spatial_scale, features,
rois, output)
else:
raise NotImplementedError
roi_align.roi_align_forward(self.aligned_height,
self.aligned_width,
self.spatial_scale, features,
rois, output)
# raise NotImplementedError

return output

Expand Down
23 changes: 19 additions & 4 deletions pooling/roi_align/make.sh
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
#!/usr/bin/env bash

CUDA_PATH=/usr/local/cuda/
# CUDA_PATH=/usr/local/cuda/

cd src
echo "Compiling my_lib kernels by nvcc..."
nvcc -c -o roi_align_kernel.cu.o roi_align_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
export CUDA_PATH=/usr/local/cuda/
#You may also want to ad the following
#export C_INCLUDE_PATH=/opt/cuda/include

export CXXFLAGS="-std=c++11"
export CFLAGS="-std=c99"

CUDA_ARCH="-gencode arch=compute_30,code=sm_30 \
-gencode arch=compute_35,code=sm_35 \
-gencode arch=compute_50,code=sm_50 \
-gencode arch=compute_52,code=sm_52 \
-gencode arch=compute_60,code=sm_60 \
-gencode arch=compute_61,code=sm_61 "

# compile roi_align
cd src
echo "Compiling roi align kernels by nvcc..."
nvcc -c -o roi_align_kernel.cu.o roi_align_kernel.cu \
-D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC $CUDA_ARCH
cd ../
python build.py
190 changes: 190 additions & 0 deletions pooling/roi_align/src/roi_align.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
#include <TH/TH.h>
#include <math.h>
#include <omp.h>


void ROIAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois,
const int height, const int width, const int channels,
const int aligned_height, const int aligned_width, const float * bottom_rois,
float* top_data);

void ROIAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois,
const int height, const int width, const int channels,
const int aligned_height, const int aligned_width, const float * bottom_rois,
float* top_data);

int roi_align_forward(int aligned_height, int aligned_width, float spatial_scale,
THFloatTensor * features, THFloatTensor * rois, THFloatTensor * output)
{
//Grab the input tensor
float * data_flat = THFloatTensor_data(features);
float * rois_flat = THFloatTensor_data(rois);

float * output_flat = THFloatTensor_data(output);

// Number of ROIs
int num_rois = THFloatTensor_size(rois, 0);
int size_rois = THFloatTensor_size(rois, 1);
if (size_rois != 5)
{
return 0;
}

// data height
int data_height = THFloatTensor_size(features, 2);
// data width
int data_width = THFloatTensor_size(features, 3);
// Number of channels
int num_channels = THFloatTensor_size(features, 1);

// do ROIAlignForward
ROIAlignForwardCpu(data_flat, spatial_scale, num_rois, data_height, data_width, num_channels,
aligned_height, aligned_width, rois_flat, output_flat);

return 1;
}

int roi_align_backward(int aligned_height, int aligned_width, float spatial_scale,
THFloatTensor * top_grad, THFloatTensor * rois, THFloatTensor * bottom_grad)
{
//Grab the input tensor
float * top_grad_flat = THFloatTensor_data(top_grad);
float * rois_flat = THFloatTensor_data(rois);

float * bottom_grad_flat = THFloatTensor_data(bottom_grad);

// Number of ROIs
int num_rois = THFloatTensor_size(rois, 0);
int size_rois = THFloatTensor_size(rois, 1);
if (size_rois != 5)
{
return 0;
}

// batch size
// int batch_size = THFloatTensor_size(bottom_grad, 0);
// data height
int data_height = THFloatTensor_size(bottom_grad, 2);
// data width
int data_width = THFloatTensor_size(bottom_grad, 3);
// Number of channels
int num_channels = THFloatTensor_size(bottom_grad, 1);

// do ROIAlignBackward
ROIAlignBackwardCpu(top_grad_flat, spatial_scale, num_rois, data_height,
data_width, num_channels, aligned_height, aligned_width, rois_flat, bottom_grad_flat);

return 1;
}

void ROIAlignForwardCpu(const float* bottom_data, const float spatial_scale, const int num_rois,
const int height, const int width, const int channels,
const int aligned_height, const int aligned_width, const float * bottom_rois,
float* top_data)
{
const int output_size = num_rois * aligned_height * aligned_width * channels;

int idx = 0;
for (idx = 0; idx < output_size; ++idx)
{
// (n, c, ph, pw) is an element in the aligned output
int pw = idx % aligned_width;
int ph = (idx / aligned_width) % aligned_height;
int c = (idx / aligned_width / aligned_height) % channels;
int n = idx / aligned_width / aligned_height / channels;

float roi_batch_ind = bottom_rois[n * 5 + 0];
float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale;
float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale;
float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale;
float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale;

// Force malformed ROI to be 1x1
float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.);
float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.);
float bin_size_h = roi_height / (aligned_height - 1.);
float bin_size_w = roi_width / (aligned_width - 1.);

float h = (float)(ph) * bin_size_h + roi_start_h;
float w = (float)(pw) * bin_size_w + roi_start_w;

int hstart = fminf(floor(h), height - 2);
int wstart = fminf(floor(w), width - 2);

int img_start = roi_batch_ind * channels * height * width;

// bilinear interpolation
if (h < 0 || h >= height || w < 0 || w >= width)
{
top_data[idx] = 0.;
}
else
{
float h_ratio = h - (float)(hstart);
float w_ratio = w - (float)(wstart);
int upleft = img_start + (c * height + hstart) * width + wstart;
int upright = upleft + 1;
int downleft = upleft + width;
int downright = downleft + 1;

top_data[idx] = bottom_data[upleft] * (1. - h_ratio) * (1. - w_ratio)
+ bottom_data[upright] * (1. - h_ratio) * w_ratio
+ bottom_data[downleft] * h_ratio * (1. - w_ratio)
+ bottom_data[downright] * h_ratio * w_ratio;
}
}
}

void ROIAlignBackwardCpu(const float* top_diff, const float spatial_scale, const int num_rois,
const int height, const int width, const int channels,
const int aligned_height, const int aligned_width, const float * bottom_rois,
float* bottom_diff)
{
const int output_size = num_rois * aligned_height * aligned_width * channels;

int idx = 0;
for (idx = 0; idx < output_size; ++idx)
{
// (n, c, ph, pw) is an element in the aligned output
int pw = idx % aligned_width;
int ph = (idx / aligned_width) % aligned_height;
int c = (idx / aligned_width / aligned_height) % channels;
int n = idx / aligned_width / aligned_height / channels;

float roi_batch_ind = bottom_rois[n * 5 + 0];
float roi_start_w = bottom_rois[n * 5 + 1] * spatial_scale;
float roi_start_h = bottom_rois[n * 5 + 2] * spatial_scale;
float roi_end_w = bottom_rois[n * 5 + 3] * spatial_scale;
float roi_end_h = bottom_rois[n * 5 + 4] * spatial_scale;

// Force malformed ROI to be 1x1
float roi_width = fmaxf(roi_end_w - roi_start_w + 1., 0.);
float roi_height = fmaxf(roi_end_h - roi_start_h + 1., 0.);
float bin_size_h = roi_height / (aligned_height - 1.);
float bin_size_w = roi_width / (aligned_width - 1.);

float h = (float)(ph) * bin_size_h + roi_start_h;
float w = (float)(pw) * bin_size_w + roi_start_w;

int hstart = fminf(floor(h), height - 2);
int wstart = fminf(floor(w), width - 2);

int img_start = roi_batch_ind * channels * height * width;

// bilinear interpolation
if (h < 0 || h >= height || w < 0 || w >= width)
{
float h_ratio = h - (float)(hstart);
float w_ratio = w - (float)(wstart);
int upleft = img_start + (c * height + hstart) * width + wstart;
int upright = upleft + 1;
int downleft = upleft + width;
int downright = downleft + 1;

bottom_diff[upleft] += top_diff[idx] * (1. - h_ratio) * (1. - w_ratio);
bottom_diff[upright] += top_diff[idx] * (1. - h_ratio) * w_ratio;
bottom_diff[downleft] += top_diff[idx] * h_ratio * (1. - w_ratio);
bottom_diff[downright] += top_diff[idx] * h_ratio * w_ratio;
}
}
}
5 changes: 5 additions & 0 deletions pooling/roi_align/src/roi_align.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
int roi_align_forward(int aligned_height, int aligned_width, float spatial_scale,
THFloatTensor * features, THFloatTensor * rois, THFloatTensor * output);

int roi_align_backward(int aligned_height, int aligned_width, float spatial_scale,
THFloatTensor * top_grad, THFloatTensor * rois, THFloatTensor * bottom_grad);

0 comments on commit d5d17f3

Please sign in to comment.