diff --git a/main.py b/main.py index 2f04e34e8..77833a4fe 100644 --- a/main.py +++ b/main.py @@ -14,7 +14,7 @@ import argparse from models import * -from utils import progress_bar +from utils import get_progress_bar, update_progress_bar from torch.autograd import Variable @@ -82,10 +82,12 @@ # Training def train(epoch): print('\nEpoch: %d' % epoch) + print('Train') net.train() train_loss = 0 correct = 0 total = 0 + progress_bar_obj = get_progress_bar(len(trainloader)) for batch_idx, (inputs, targets) in enumerate(trainloader): if use_cuda: inputs, targets = inputs.cuda(), targets.cuda() @@ -101,15 +103,17 @@ def train(epoch): total += targets.size(0) correct += predicted.eq(targets.data).cpu().sum() - progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' - % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) + update_progress_bar(progress_bar_obj, index=batch_idx, loss=(train_loss / (batch_idx + 1)), + acc=(correct / total), c=correct, t=total) def test(epoch): + print('\nTest') global best_acc net.eval() test_loss = 0 correct = 0 total = 0 + progress_bar_obj = get_progress_bar(len(testloader)) for batch_idx, (inputs, targets) in enumerate(testloader): if use_cuda: inputs, targets = inputs.cuda(), targets.cuda() @@ -122,13 +126,13 @@ def test(epoch): total += targets.size(0) correct += predicted.eq(targets.data).cpu().sum() - progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' - % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) + update_progress_bar(progress_bar_obj, index=batch_idx, loss=(test_loss / (batch_idx + 1)), + acc=(correct / total), c=correct, t=total) # Save checkpoint. acc = 100.*correct/total if acc > best_acc: - print('Saving..') + print('\nSaving..') state = { 'net': net.module if use_cuda else net, 'acc': acc, diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..ffa212e97 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +torch +torchvision +progressbar2 \ No newline at end of file diff --git a/utils.py b/utils.py index 4c9b3f90c..59978e778 100644 --- a/utils.py +++ b/utils.py @@ -8,6 +8,7 @@ import time import math +import progressbar import torch.nn as nn import torch.nn.init as init @@ -42,54 +43,30 @@ def init_params(net): init.constant(m.bias, 0) -_, term_width = os.popen('stty size', 'r').read().split() -term_width = int(term_width) +def get_progress_bar(total): + format_custom_text = progressbar.FormatCustomText( + 'Loss: %(loss).3f | Acc: %(acc).3f%% (%(c)d/%(t)d)', + dict( + loss=0, + acc=0, + c=0, + t=0, + ), + ) + prog_bar = progressbar.ProgressBar(0, total, widgets=[ + progressbar.Counter(), ' of {} '.format(total), + progressbar.Bar(), + ' ', progressbar.ETA(), + ' ', format_custom_text + ]) + return prog_bar, format_custom_text + + +def update_progress_bar(progress_bar_obj, index=None, loss=None, acc=None, c=None, t=None): + prog_bar, format_custom_text = progress_bar_obj + format_custom_text.update_mapping(loss=loss, acc=acc, c=c, t=t) + prog_bar.update(index) -TOTAL_BAR_LENGTH = 65. -last_time = time.time() -begin_time = last_time -def progress_bar(current, total, msg=None): - global last_time, begin_time - if current == 0: - begin_time = time.time() # Reset for new bar. - - cur_len = int(TOTAL_BAR_LENGTH*current/total) - rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 - - sys.stdout.write(' [') - for i in range(cur_len): - sys.stdout.write('=') - sys.stdout.write('>') - for i in range(rest_len): - sys.stdout.write('.') - sys.stdout.write(']') - - cur_time = time.time() - step_time = cur_time - last_time - last_time = cur_time - tot_time = cur_time - begin_time - - L = [] - L.append(' Step: %s' % format_time(step_time)) - L.append(' | Tot: %s' % format_time(tot_time)) - if msg: - L.append(' | ' + msg) - - msg = ''.join(L) - sys.stdout.write(msg) - for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): - sys.stdout.write(' ') - - # Go back to the center of the bar. - for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): - sys.stdout.write('\b') - sys.stdout.write(' %d/%d ' % (current+1, total)) - - if current < total-1: - sys.stdout.write('\r') - else: - sys.stdout.write('\n') - sys.stdout.flush() def format_time(seconds): days = int(seconds / 3600/24)