Skip to content

Commit

Permalink
Bugs updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Pharaun85 committed Oct 22, 2020
1 parent 8144d58 commit 88bce6f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
20 changes: 13 additions & 7 deletions dataloaders/sequencedataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def __init__(self, folders, transform=None):
for folder in folders:
folder_image_02 = os.path.join(folder, 'image_02')
for image_02_file in os.listdir(folder_image_02):
if os.path.isfile(os.path.join(folder_image_02, image_02_file)) and '.png' in image_02_file:
_, ext = os.path.splitext(image_02_file)
if os.path.isfile(os.path.join(folder_image_02, image_02_file)) and (ext == '.png'):
image_02.append(os.path.join(folder_image_02, image_02_file))

self.image_02 = image_02
Expand All @@ -117,12 +118,17 @@ def __getitem__(self, idx):
image = Image.open(imagepath)

# Obtaining ground truth
head, tail = os.path.split(imagepath)
head, _ = os.path.split(head)
filename, _ = os.path.splitext(tail)
gt_path = os.path.join(head, 'frames_topology.txt')
gtdata = pd.read_csv(gt_path, sep=';', header=None, dtype=str)
gTruth = int(gtdata.loc[gtdata[0] == filename][2])
if os.path.isfile(imagepath+'.json'):
with open(imagepath+'.json') as json_file:
gTruth_info = json.load(json_file)
gTruth = int(gTruth_info['label'])
else:
head, tail = os.path.split(imagepath)
head, _ = os.path.split(head)
filename, _ = os.path.splitext(tail)
gt_path = os.path.join(head, 'frames_topology.txt')
gtdata = pd.read_csv(gt_path, sep=';', header=None, dtype=str)
gTruth = int(gtdata.loc[gtdata[0] == filename][2])

sample = {'data': image,
'label': gTruth}
Expand Down
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,8 @@ def main(args, model=None):
if args.wandb_group_id:
group_id = args.wandb_group_id
else:
group_id = 'Kitti2011_RGB'
group_id = 'Kitti2011_Homography'

print(args)
warnings.filterwarnings("ignore")

Expand Down

0 comments on commit 88bce6f

Please sign in to comment.