Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add requirements.txt and progress bar by using progressbar2. #27

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import argparse

from models import *
from utils import progress_bar
from utils import get_progress_bar, update_progress_bar
from torch.autograd import Variable


Expand Down Expand Up @@ -82,10 +82,12 @@
# Training
def train(epoch):
print('\nEpoch: %d' % epoch)
print('Train')
net.train()
train_loss = 0
correct = 0
total = 0
progress_bar_obj = get_progress_bar(len(trainloader))
for batch_idx, (inputs, targets) in enumerate(trainloader):
if use_cuda:
inputs, targets = inputs.cuda(), targets.cuda()
Expand All @@ -101,15 +103,17 @@ def train(epoch):
total += targets.size(0)
correct += predicted.eq(targets.data).cpu().sum()

progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
update_progress_bar(progress_bar_obj, index=batch_idx, loss=(train_loss / (batch_idx + 1)),
acc=(correct / total), c=correct, t=total)

def test(epoch):
print('\nTest')
global best_acc
net.eval()
test_loss = 0
correct = 0
total = 0
progress_bar_obj = get_progress_bar(len(testloader))
for batch_idx, (inputs, targets) in enumerate(testloader):
if use_cuda:
inputs, targets = inputs.cuda(), targets.cuda()
Expand All @@ -122,13 +126,13 @@ def test(epoch):
total += targets.size(0)
correct += predicted.eq(targets.data).cpu().sum()

progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
update_progress_bar(progress_bar_obj, index=batch_idx, loss=(test_loss / (batch_idx + 1)),
acc=(correct / total), c=correct, t=total)

# Save checkpoint.
acc = 100.*correct/total
if acc > best_acc:
print('Saving..')
print('\nSaving..')
state = {
'net': net.module if use_cuda else net,
'acc': acc,
Expand Down
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
torch
torchvision
progressbar2
71 changes: 24 additions & 47 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import time
import math

import progressbar
import torch.nn as nn
import torch.nn.init as init

Expand Down Expand Up @@ -42,54 +43,30 @@ def init_params(net):
init.constant(m.bias, 0)


_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)
def get_progress_bar(total):
format_custom_text = progressbar.FormatCustomText(
'Loss: %(loss).3f | Acc: %(acc).3f%% (%(c)d/%(t)d)',
dict(
loss=0,
acc=0,
c=0,
t=0,
),
)
prog_bar = progressbar.ProgressBar(0, total, widgets=[
progressbar.Counter(), ' of {} '.format(total),
progressbar.Bar(),
' ', progressbar.ETA(),
' ', format_custom_text
])
return prog_bar, format_custom_text


def update_progress_bar(progress_bar_obj, index=None, loss=None, acc=None, c=None, t=None):
prog_bar, format_custom_text = progress_bar_obj
format_custom_text.update_mapping(loss=loss, acc=acc, c=c, t=t)
prog_bar.update(index)

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
global last_time, begin_time
if current == 0:
begin_time = time.time() # Reset for new bar.

cur_len = int(TOTAL_BAR_LENGTH*current/total)
rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

sys.stdout.write(' [')
for i in range(cur_len):
sys.stdout.write('=')
sys.stdout.write('>')
for i in range(rest_len):
sys.stdout.write('.')
sys.stdout.write(']')

cur_time = time.time()
step_time = cur_time - last_time
last_time = cur_time
tot_time = cur_time - begin_time

L = []
L.append(' Step: %s' % format_time(step_time))
L.append(' | Tot: %s' % format_time(tot_time))
if msg:
L.append(' | ' + msg)

msg = ''.join(L)
sys.stdout.write(msg)
for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
sys.stdout.write(' ')

# Go back to the center of the bar.
for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
sys.stdout.write('\b')
sys.stdout.write(' %d/%d ' % (current+1, total))

if current < total-1:
sys.stdout.write('\r')
else:
sys.stdout.write('\n')
sys.stdout.flush()

def format_time(seconds):
days = int(seconds / 3600/24)
Expand Down