Commit: dimension fix

jeoseo committed Apr 25, 2022
1 parent 902fc6c commit 97c3148
Showing 6 changed files with 17 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -5,4 +5,4 @@ logs/
 results/
 __pycache__/
 code/dataset/filenames/nyudepthv2/
-code/dataset/weights/
+code/dataset/weights/*
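
Aside (a sketch by way of explanation, not part of the commit): ignoring code/dataset/weights/ makes git skip the directory entirely, so nothing inside it can ever be re-included. Ignoring its contents with code/dataset/weights/* instead leaves room for a negation rule, e.g. to keep a hypothetical placeholder file so the otherwise-empty directory stays tracked:

    code/dataset/weights/*
    !code/dataset/weights/.gitkeep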
2 changes: 1 addition & 1 deletion code/dataset/base_dataset.py
@@ -37,6 +37,7 @@ def readTXT(self, txt_path):
 
     def augment_training_data(self, image, depth):
         H, W, C = image.shape
+
         if self.count % 4 == 0:
             alpha = random.random()
             beta = random.random()
@@ -68,4 +69,3 @@ def augment_test_data(self, image, depth):
         depth = self.to_tensor(depth).squeeze()
 
         return image, depth
-
4 changes: 2 additions & 2 deletions code/dataset/nyudepthv2.py
@@ -48,8 +48,8 @@ def __getitem__(self, idx):
 
         H, W, C = image.shape
         #first, image and depth must be cropped because of the large borders on the depth map
-        image=image[int(W/10),int(W*19/20),int(W/12):int(W*11/12)]
-        depth=depth[int(W/10),int(W*19/20),int(W/12):int(W*11/12)]
+        image=image[40:460,40:600,:]
+        depth=depth[40:460,40:600]
 
         if self.is_train and self.do_cutdepth:
             image, depth = self.augment_training_data(image, depth)
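
Aside (a minimal sketch, not from the repo, of why this is a "dimension fix"): the old lines used commas where slice colons were intended, so with W = 640 they selected the single pixel at (64, 608) and then sliced the 3-element channel axis with 53:586, yielding an empty array instead of a cropped image. The new lines slice rows 40:460 and columns 40:600. A hypothetical NumPy check:

    import numpy as np

    image = np.zeros((480, 640, 3), dtype=np.uint8)  # raw NYU Depth v2 frame size
    W = image.shape[1]

    # Old, buggy indexing: integer indices on rows/cols, slice on channels
    old = image[int(W/10), int(W*19/20), int(W/12):int(W*11/12)]
    # Fixed indexing: slices on rows and columns, all channels kept
    new = image[40:460, 40:600, :]

    print(old.shape)  # (0,)  -- empty, not an image
    print(new.shape)  # (420, 560, 3)  -- the intended border crop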
5 changes: 0 additions & 5 deletions code/dataset/sunrgbd.py
@@ -44,11 +44,6 @@ def __getitem__(self, idx):
         if self.scale_size:
             image = cv2.resize(image, (self.scale_size[1], self.scale_size[0]))
             depth = cv2.resize(depth, (self.scale_size[1], self.scale_size[0]))
-        H, W, C = image.shape
-        #first, image and depth must be cropped because of the large borders on the depth map
-        image=image[int(W/10),int(W*19/20),int(W/12):int(W*11/12)]
-        depth=depth[int(W/10),int(W*19/20),int(W/12):int(W*11/12)]
-
         if self.is_train and self.do_cutdepth:
             image, depth = self.augment_training_data(image, depth)
         else:
22 changes: 11 additions & 11 deletions code/models/model.py
@@ -10,17 +10,17 @@ def __init__(self, max_depth=10.0, is_train=False):
         self.max_depth = max_depth
 
         self.encoder = mit_b4()
-        if is_train:
-            ckpt_path = './GLPDepth/code/models/weights/mit_b4.pth'
-            try:
-                load_checkpoint(self.encoder, ckpt_path, logger=None)
-            except:
-                import gdown
-                print("Download pre-trained encoder weights...")
-                id = '1BUtU42moYrOFbsMCE-LTTkUE-mrWnfG2'
-                url = 'https://drive.google.com/uc?id=' + id
-                output = './GLPDepth/code/models/weights/mit_b4.pth'
-                gdown.download(url, output, quiet=False)
+        # if is_train:
+        #     ckpt_path = './GLPDepth/code/models/weights/mit_b4.pth'
+        #     try:
+        #         load_checkpoint(self.encoder, ckpt_path, logger=None)
+        #     except:
+        #         import gdown
+        #         print("Download pre-trained encoder weights...")
+        #         id = '1BUtU42moYrOFbsMCE-LTTkUE-mrWnfG2'
+        #         url = 'https://drive.google.com/uc?id=' + id
+        #         output = './GLPDepth/code/models/weights/mit_b4.pth'
+        #         gdown.download(url, output, quiet=False)
 
         channels_in = [512, 320, 128]
         channels_out = 64
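
Aside (a hedged sketch, not part of the commit): with the in-model download disabled, the encoder weights can be fetched and loaded manually. The path and Google Drive id come from the commented-out block above; the class name and torch usage are assumptions:

    import os
    import gdown
    import torch

    ckpt_path = './GLPDepth/code/models/weights/mit_b4.pth'
    if not os.path.exists(ckpt_path):
        # Same Google Drive file the removed code fetched
        url = 'https://drive.google.com/uc?id=1BUtU42moYrOFbsMCE-LTTkUE-mrWnfG2'
        gdown.download(url, ckpt_path, quiet=False)

    state = torch.load(ckpt_path, map_location='cpu')
    state = state.get('state_dict', state)  # mmcv checkpoints may nest the weights
    model = GLPDepth(max_depth=10.0, is_train=True)  # assumed name of the class in this file
    model.encoder.load_state_dict(state, strict=False)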
4 changes: 2 additions & 2 deletions code/train.py
@@ -57,9 +57,9 @@ def main():
     start=time.time()
 
     # Dataset setting
-    dataset_kwargs = {'dataset_name': args.dataset, 'data_path': args.data_path,'filenames_path':args.filenames_path}
+    dataset_kwargs = {'dataset_name': args.dataset, 'data_path': args.data_path,'filenames_path':args.filenames_path,'do_cutdepth':args.do_cutdepth}
     if args.dataset == 'nyudepthv2':
-        dataset_kwargs['crop_size'] = (448, 576)
+        dataset_kwargs['crop_size'] = (448-32,576-32)
     elif args.dataset == 'kitti':
         dataset_kwargs['crop_size'] = (352, 704)
     else:
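
Aside (an illustrative check, not in the commit): after the border trim in nyudepthv2.py the frames are 420x560, so the reduced training crop (448-32, 576-32) = (416, 544) fits inside them, whereas the old (448, 576) would not:

    # Raw NYU Depth v2 frames are 480x640; the trim keeps rows 40:460, cols 40:600
    trim_h, trim_w = 460 - 40, 600 - 40   # 420, 560
    new_crop = (448 - 32, 576 - 32)       # (416, 544) -- fits in the trimmed frame
    old_crop = (448, 576)                 # larger than the trimmed frame on both axes
    assert new_crop[0] <= trim_h and new_crop[1] <= trim_w
    assert old_crop[0] > trim_h and old_crop[1] > trim_w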
