diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..ccb802c78
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,15 @@
+data/
+checkpoint/
+__pycache__/
+
+# for pycharm
+.idea/
+
+# for log
+*.log
+
+# checkpoints are too big
+saved_ckpt/
+
+# tar
+*.tar
diff --git a/main.py b/main.py
index c1b46ee5e..d93c418e3 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,6 @@
+# -*- coding: utf-8 -*-
 '''Train CIFAR10 with PyTorch.'''
+
 from __future__ import print_function
 
 import torch
@@ -18,7 +20,7 @@
 
 parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
-parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
+parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
 parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
 args = parser.parse_args()
 
@@ -50,7 +52,7 @@
 
 # Model
 print('==> Building model..')
-# net = VGG('VGG19')
+net = VGG('VGG16')
 # net = ResNet18()
 # net = PreActResNet18()
 # net = GoogLeNet()
@@ -61,7 +63,7 @@
 # net = DPN92()
 # net = ShuffleNetG2()
 # net = SENet18()
-net = ShuffleNetV2(1)
+# net = ShuffleNetV2(1)
 net = net.to(device)
 if device == 'cuda':
     net = torch.nn.DataParallel(net)
@@ -74,10 +76,12 @@
     checkpoint = torch.load('./checkpoint/ckpt.t7')
     net.load_state_dict(checkpoint['net'])
     best_acc = checkpoint['acc']
-    start_epoch = checkpoint['epoch']
+    start_epoch = checkpoint['epoch'] + 1
 
-criterion = nn.CrossEntropyLoss()
-optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
+# criterion = nn.CrossEntropyLoss()
+criterion = nn.KLDivLoss(reduction='mean')
+# optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
+optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=5e-4)
 
 # Training
 def train(epoch):
@@ -104,11 +108,11 @@ def train(epoch):
 
 def test(epoch):
     global best_acc
-    net.eval()
+    net.eval()  # switch to evaluation mode; affects dropout and batch normalization
     test_loss = 0
     correct = 0
     total = 0
-    with torch.no_grad():
+    with torch.no_grad():  # no gradients are needed for evaluation, which saves time and memory
         for batch_idx, (inputs, targets) in enumerate(testloader):
             inputs, targets = inputs.to(device), targets.to(device)
             outputs = net(inputs)
@@ -137,6 +141,6 @@ def test(epoch):
         best_acc = acc
 
 
-for epoch in range(start_epoch, start_epoch+200):
+for epoch in range(start_epoch, 61):
     train(epoch)
     test(epoch)
diff --git a/main_L1Loss.py b/main_L1Loss.py
new file mode 100644
index 000000000..e49054c8f
--- /dev/null
+++ b/main_L1Loss.py
@@ -0,0 +1,157 @@
+# -*- coding: utf-8 -*-
+'''Train CIFAR10 with PyTorch.'''
+
+from __future__ import print_function
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+
+import torchvision
+import torchvision.transforms as transforms
+
+import os
+import argparse
+
+from models import *
+from utils import progress_bar
+
+
+parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
+parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
+parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
+args = parser.parse_args()
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+best_acc = 0 # best test accuracy
+start_epoch = 0 # start from epoch 0 or last checkpoint epoch
+
+# Data
+print('==> Preparing data..')
+transform_train = transforms.Compose([
+    transforms.RandomCrop(32, padding=4),
+    transforms.RandomHorizontalFlip(),
+    transforms.ToTensor(),
+    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+])
+
+transform_test = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+])
+
+trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
+
+testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
+
+classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
+
+# Model
+print('==> Building model..')
+net = VGG('VGG16')
+# net = ResNet18()
+# net = PreActResNet18()
+# net = GoogLeNet()
+# net = DenseNet121()
+# net = ResNeXt29_2x64d()
+# net = MobileNet()
+# net = MobileNetV2()
+# net = DPN92()
+# net = ShuffleNetG2()
+# net = SENet18()
+# net = ShuffleNetV2(1)
+net = net.to(device)
+if device == 'cuda':
+    net = torch.nn.DataParallel(net)
+    cudnn.benchmark = True
+
+if args.resume:
+    # Load checkpoint.
+    print('==> Resuming from checkpoint..')
+    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
+    checkpoint = torch.load('./checkpoint/ckpt.t7')
+    net.load_state_dict(checkpoint['net'])
+    best_acc = checkpoint['acc']
+    start_epoch = checkpoint['epoch'] + 1
+
+# criterion = nn.CrossEntropyLoss()
+# criterion = nn.MSELoss(reduction='sum')
+criterion = nn.L1Loss(reduction='sum')
+# optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
+optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=5e-4)
+
+# Training
+def train(epoch):
+    print('\nEpoch: %d' % epoch)
+    net.train()
+    train_loss = 0
+    correct = 0
+    total = 0
+    for batch_idx, (inputs, targets) in enumerate(trainloader):
+        inputs, targets = inputs.to(device), targets.to(device)
+        optimizer.zero_grad()
+        outputs = net(inputs)
+        targets_mse = torch.zeros(outputs.size()).to(device)
+        for i in range(0, targets_mse.size()[0]):
+            targets_mse[i][targets[i]] = 1  # one-hot encode the class labels
+        softmax = F.softmax(outputs, dim=0)
+        # print(targets_mse)
+        # print(outputs, targets)
+        loss = criterion(softmax, targets_mse)
+        loss.backward()
+        optimizer.step()
+
+        train_loss += loss.item()
+        _, predicted = outputs.max(1)
+        total += targets.size(0)
+        correct += predicted.eq(targets).sum().item()
+
+        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
+            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
+
+def test(epoch):
+    global best_acc
+    net.eval()  # switch to evaluation mode; affects dropout and batch normalization
+    test_loss = 0
+    correct = 0
+    total = 0
+    with torch.no_grad():  # no gradients are needed for evaluation, which saves time and memory
+        for batch_idx, (inputs, targets) in enumerate(testloader):
+            inputs, targets = inputs.to(device), targets.to(device)
+            outputs = net(inputs)
+            targets_mse = torch.zeros(outputs.size()).to(device)
+            for i in range(0, targets_mse.size()[0]):
+                targets_mse[i][targets[i]] = 1  # one-hot encode the class labels
+            softmax = F.softmax(outputs, dim=0)
+            loss = criterion(softmax, targets_mse)
+
+            test_loss += loss.item()
+            _, predicted = outputs.max(1)
+            total += targets.size(0)
+            correct += predicted.eq(targets).sum().item()
+
+            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
+                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
+
+    # Save checkpoint.
+    acc = 100.*correct/total
+    if acc > best_acc:
+        print('Saving..')
+        state = {
+            'net': net.state_dict(),
+            'acc': acc,
+            'epoch': epoch,
+        }
+        if not os.path.isdir('checkpoint'):
+            os.mkdir('checkpoint')
+        torch.save(state, './checkpoint/ckpt.t7')
+        best_acc = acc
+
+
+for epoch in range(start_epoch, 61):
+    train(epoch)
+    test(epoch)
diff --git a/main_mse.py b/main_mse.py
new file mode 100644
index 000000000..67a58c145
--- /dev/null
+++ b/main_mse.py
@@ -0,0 +1,156 @@
+# -*- coding: utf-8 -*-
+'''Train CIFAR10 with PyTorch.'''
+
+from __future__ import print_function
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+
+import torchvision
+import torchvision.transforms as transforms
+
+import os
+import argparse
+
+from models import *
+from utils import progress_bar
+
+
+parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
+parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
+parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
+args = parser.parse_args()
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+best_acc = 0 # best test accuracy
+start_epoch = 0 # start from epoch 0 or last checkpoint epoch
+
+# Data
+print('==> Preparing data..')
+transform_train = transforms.Compose([
+    transforms.RandomCrop(32, padding=4),
+    transforms.RandomHorizontalFlip(),
+    transforms.ToTensor(),
+    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+])
+
+transform_test = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+])
+
+trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
+
+testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
+
+classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
+
+# Model
+print('==> Building model..')
+net = VGG('VGG16')
+# net = ResNet18()
+# net = PreActResNet18()
+# net = GoogLeNet()
+# net = DenseNet121()
+# net = ResNeXt29_2x64d()
+# net = MobileNet()
+# net = MobileNetV2()
+# net = DPN92()
+# net = ShuffleNetG2()
+# net = SENet18()
+# net = ShuffleNetV2(1)
+net = net.to(device)
+if device == 'cuda':
+    net = torch.nn.DataParallel(net)
+    cudnn.benchmark = True
+
+if args.resume:
+    # Load checkpoint.
+    print('==> Resuming from checkpoint..')
+    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
+    checkpoint = torch.load('./checkpoint/ckpt.t7')
+    net.load_state_dict(checkpoint['net'])
+    best_acc = checkpoint['acc']
+    start_epoch = checkpoint['epoch'] + 1
+
+# criterion = nn.CrossEntropyLoss()
+criterion = nn.MSELoss(reduction='sum')
+# optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
+optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=5e-4)
+
+# Training
+def train(epoch):
+    print('\nEpoch: %d' % epoch)
+    net.train()
+    train_loss = 0
+    correct = 0
+    total = 0
+    for batch_idx, (inputs, targets) in enumerate(trainloader):
+        inputs, targets = inputs.to(device), targets.to(device)
+        optimizer.zero_grad()
+        outputs = net(inputs)
+        targets_mse = torch.zeros(outputs.size()).to(device)
+        for i in range(0, targets_mse.size()[0]):
+            targets_mse[i][targets[i]] = 1  # one-hot encode the class labels
+        softmax = F.softmax(outputs, dim=0)
+        # print(targets_mse)
+        # print(outputs, targets)
+        loss = criterion(softmax, targets_mse)
+        loss.backward()
+        optimizer.step()
+
+        train_loss += loss.item()
+        _, predicted = outputs.max(1)
+        total += targets.size(0)
+        correct += predicted.eq(targets).sum().item()
+
+        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
+            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
+
+def test(epoch):
+    global best_acc
+    net.eval()  # switch to evaluation mode; affects dropout and batch normalization
+    test_loss = 0
+    correct = 0
+    total = 0
+    with torch.no_grad():  # no gradients are needed for evaluation, which saves time and memory
+        for batch_idx, (inputs, targets) in enumerate(testloader):
+            inputs, targets = inputs.to(device), targets.to(device)
+            outputs = net(inputs)
+            targets_mse = torch.zeros(outputs.size()).to(device)
+            for i in range(0, targets_mse.size()[0]):
+                targets_mse[i][targets[i]] = 1  # one-hot encode the class labels
+            softmax = F.softmax(outputs, dim=0)
+            loss = criterion(softmax, targets_mse)
+
+            test_loss += loss.item()
+            _, predicted = outputs.max(1)
+            total += targets.size(0)
+            correct += predicted.eq(targets).sum().item()
+
+            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
+                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
+
+    # Save checkpoint.
+    acc = 100.*correct/total
+    if acc > best_acc:
+        print('Saving..')
+        state = {
+            'net': net.state_dict(),
+            'acc': acc,
+            'epoch': epoch,
+        }
+        if not os.path.isdir('checkpoint'):
+            os.mkdir('checkpoint')
+        torch.save(state, './checkpoint/ckpt.t7')
+        best_acc = acc
+
+
+for epoch in range(start_epoch, 61):
+    train(epoch)
+    test(epoch)
diff --git a/models/__init__.py b/models/__init__.py
index 7f67a9e55..877893903 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,14 +1,14 @@
-from vgg import *
-from dpn import *
-from lenet import *
-from senet import *
-from pnasnet import *
-from densenet import *
-from googlenet import *
-from shufflenet import *
-from shufflenetv2 import *
-from resnet import *
-from resnext import *
-from preact_resnet import *
-from mobilenet import *
-from mobilenetv2 import *
+from .vgg import *
+from .dpn import *
+from .lenet import *
+from .senet import *
+from .pnasnet import *
+from .densenet import *
+from .googlenet import *
+from .shufflenet import *
+from .shufflenetv2 import *
+from .resnet import *
+from .resnext import *
+from .preact_resnet import *
+from .mobilenet import *
+from .mobilenetv2 import *
diff --git a/models/shufflenetv2.py b/models/shufflenetv2.py
index d24c5dcbb..1bf45320d 100644
--- a/models/shufflenetv2.py
+++ b/models/shufflenetv2.py
@@ -16,7 +16,7 @@ def forward(self, x):
         '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]'''
         N, C, H, W = x.size()
         g = self.groups
-        return x.view(N, g, C/g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
+        return x.view(N, g, int(C/g), H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
 
 
 class SplitBlock(nn.Module):
diff --git a/mydata/0.jpg b/mydata/0.jpg
new file mode 100644
index 000000000..920295e1a
Binary files /dev/null and b/mydata/0.jpg differ
diff --git a/mydata/1.jpg b/mydata/1.jpg
new file mode 100644
index 000000000..535793e95
Binary files /dev/null and b/mydata/1.jpg differ
diff --git a/mydata/2.jpg b/mydata/2.jpg
new file mode 100644
index 000000000..071b7c717
Binary files /dev/null and b/mydata/2.jpg differ
diff --git a/mydata/3.jpg b/mydata/3.jpg
new file mode 100644
index 000000000..aab04be6a
Binary files /dev/null and b/mydata/3.jpg differ
diff --git a/mydata/4.jpg b/mydata/4.jpg
new file mode 100644
index 000000000..6eafad786
Binary files /dev/null and b/mydata/4.jpg differ
diff --git a/mydata/5.jpg b/mydata/5.jpg
new file mode 100644
index 000000000..04a88ab1b
Binary files /dev/null and b/mydata/5.jpg differ
diff --git a/mydata/6.jpg b/mydata/6.jpg
new file mode 100644
index 000000000..65d048242
Binary files /dev/null and b/mydata/6.jpg differ
diff --git a/mydata/7.jpg b/mydata/7.jpg
new file mode 100644
index 000000000..24a4d5ff5
Binary files /dev/null and b/mydata/7.jpg differ
diff --git a/mydata/8.jpg b/mydata/8.jpg
new file mode 100644
index 000000000..d5d5abd4a
Binary files /dev/null and b/mydata/8.jpg differ
diff --git a/mydata/gay.jpg b/mydata/gay.jpg
new file mode 100644
index 000000000..c20c92719
Binary files /dev/null and b/mydata/gay.jpg differ
diff --git a/mydata/images.jpg b/mydata/images.jpg
new file mode 100644
index 000000000..1420adf98
Binary files /dev/null and b/mydata/images.jpg differ
diff --git a/show/calc_cifar.py b/show/calc_cifar.py
new file mode 100644
index 000000000..5d58f4cb1
--- /dev/null
+++ b/show/calc_cifar.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from load import loadnet
+
+
+pic_num = 5
+classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
+std = (0.2023, 0.1994, 0.2010)
+mean = (0.4914, 0.4822, 0.4465)
+
+if __name__ == '__main__':
+    # Reason: see https://blog.csdn.net/xiemanR/article/details/71700531
+    # The test loader uses multiple worker processes, so on Windows it must be guarded by __name__ == '__main__'
+    transform_test = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean, std),
+    ])
+
+    testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform_test)
+    testloader = torch.utils.data.DataLoader(testset, batch_size=pic_num, shuffle=True, num_workers=2)
+
+    net, _ = loadnet(1)
+
+    net.eval()  # switch to evaluation mode; affects dropout and batch normalization
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    if torch.cuda.is_available():
+        net.cuda()
+    to_pil_image = transforms.ToPILImage()
+    with torch.no_grad():  # no gradients are needed for inference, which saves time and memory
+        (inputs, targets) = list(testloader)[0]
+        (inputs, targets) = inputs.to(device), targets.to(device)
+        outputs = net(inputs)
+        _, predicted = outputs.max(1)
+        for i in range(pic_num):
+            print("Ground-truth class: %s, predicted class: %s" % (
+                classes[targets[i]], classes[predicted[i]]))  # show the labels
+            img = inputs[i].new(*inputs[i].size())
+            img[0, :, :] = inputs[i][0, :, :] * std[0] + mean[0]
+            img[1, :, :] = inputs[i][1, :, :] * std[1] + mean[1]
+            img[2, :, :] = inputs[i][2, :, :] * std[2] + mean[2]
+            if torch.cuda.is_available():
+                img = to_pil_image(img.cpu())
+            else:
+                img = to_pil_image(img)
+            plt.imshow(img)
+            plt.show()
\ No newline at end of file
diff --git a/show/calc_mine.py b/show/calc_mine.py
new file mode 100644
index 000000000..a9fe7f8da
--- /dev/null
+++ b/show/calc_mine.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import torchvision.transforms as transforms
+
+import tkinter.filedialog
+import matplotlib.pyplot as plt
+from PIL import Image
+
+from load import loadnet
+
+
+pic_num = 5
+classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
+std = (0.2023, 0.1994, 0.2010)
+mean = (0.4914, 0.4822, 0.4465)
+
+fname = tkinter.filedialog.askopenfilename()
+
+if __name__ == '__main__':
+    # Reason: see https://blog.csdn.net/xiemanR/article/details/71700531
+    # The test loader uses multiple worker processes, so on Windows it must be guarded by __name__ == '__main__'
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean, std),
+    ])
+
+    img = Image.open(fname).convert('RGB')
+    plt.imshow(img)
+    plt.show()
+    img = img.resize((32, 32), Image.ANTIALIAS)
+    plt.imshow(img)
+    plt.show()
+    inputs = transform(img).reshape(1, 3, 32, 32)
+
+    net, _ = loadnet(1)
+
+    net.eval()  # switch to evaluation mode; affects dropout and batch normalization
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    if torch.cuda.is_available():
+        net.cuda()
+    to_pil_image = transforms.ToPILImage()
+    with torch.no_grad():  # no gradients are needed for inference, which saves time and memory
+        inputs = inputs.to(device)
+        outputs = net(inputs)
+        _, predicted = outputs.max(1)
+        print("Predicted class: %s" % (classes[predicted[0]]))  # show the label
\ No newline at end of file
diff --git a/show/load.py b/show/load.py
new file mode 100644
index 000000000..2b3ec1c40
--- /dev/null
+++ b/show/load.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import torch.backends.cudnn as cudnn
+
+import sys
+sys.path.append("..")
+
+from models import *
+
+
+class WrappedModel(nn.Module):
+    # Wraps a plain model under a .module attribute, so that checkpoints saved from
+    # nn.DataParallel (whose state_dict keys are prefixed with 'module.') load unchanged on CPU.
+    def __init__(self, module):
+        super(WrappedModel, self).__init__()
+        self.module = module  # the model that is actually defined
+    def forward(self, x):
+        return self.module(x)
+
+
+def loadnet(index):
+    if index == 1:
+        print('ckpt1: ShuffleNetV2, trained for 20 epochs, 80%; mainly used to try out the code')
+        net = ShuffleNetV2(1)
+        fname = '../saved_ckpt/ckpt1'
+    elif index == 2:
+        print('ckpt2: ShuffleNetV2, trained for 200 epochs, 80%; that seems to be its limit')
+        net = ShuffleNetV2(1)
+        fname = '../saved_ckpt/ckpt2'
+    elif index == 3:
+        print('ckpt3: DenseNet, plateaued after about 30 epochs, 87%')
+        net = DenseNet121()
+        fname = '../saved_ckpt/ckpt3'
+    elif index == 4:
+        print('ckpt4: DenseNet121, 88%')
+        net = DenseNet121()
+        fname = '../saved_ckpt/ckpt4'
+    elif index == 5:
+        print('ckpt5: DPN92, 87%')
+        net = DPN92()
+        fname = '../saved_ckpt/ckpt5'
+    elif index == 6:
+        print('ckpt6: VGG16, 83%')
+        net = VGG('VGG16')
+        fname = '../saved_ckpt/ckpt6'
+    elif index == 7:
+        print('ckpt7: VGG16, 89%')
+        net = VGG('VGG16')
+        fname = '../saved_ckpt/ckpt7'
+    elif index == 8:
+        print('ckpt8: ResNeXt, 87%')
+        net = ResNeXt29_2x64d()
+        fname = '../saved_ckpt/ckpt8'
+    elif index == 9:
+        print('ckpt9: DenseNet, 90%')
+        net = DenseNet121()
+        fname = '../saved_ckpt/ckpt9'
+    elif index == 10:
+        print('ckpt10: VGG16, MSE, 60%')
+        net = VGG('VGG16')
+        fname = '../saved_ckpt/ckpt10'
+    elif index == 11:
+        print('ckpt11: VGG16, L1, 25%')
+        net = VGG('VGG16')
+        fname = '../saved_ckpt/ckpt11'
+    else:
+        print('Invalid index')
+        return
+
+    if torch.cuda.is_available():
+        checkpoint = torch.load(fname)
+        net = torch.nn.DataParallel(net)
+        net.load_state_dict(checkpoint['net'])
+        cudnn.benchmark = True
+    else:
+        checkpoint = torch.load(fname, map_location='cpu')
+        net = WrappedModel(net)
+        net.load_state_dict(checkpoint['net'])
+    acc = checkpoint['acc']
+
+    return [net, acc]
+
+
+if __name__ == '__main__':
+    [net, acc] = loadnet(11)
+    print ("Accuracy: %f" % acc)
diff --git a/show/show.py b/show/show.py
new file mode 100644
index 000000000..19829a43d
--- /dev/null
+++ b/show/show.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import torchvision
+import torchvision.transforms as transforms
+
+from load import loadnet
+
+
+if __name__ == '__main__':
+    # Reason: see https://blog.csdn.net/xiemanR/article/details/71700531
+    # The test loader uses multiple worker processes, so on Windows it must be guarded by __name__ == '__main__'
+    transform_test = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ])
+
+    testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform_test)
+    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
+
+    [net, acc] = loadnet(1)
+    print ('Expected accuracy: %f%%' % acc)
+
+    net.eval()  # switch to evaluation mode; affects dropout and batch normalization
+    correct = 0
+    total = 0
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    with torch.no_grad():  # no gradients are needed for evaluation, which saves time and memory
+        for batch_idx, (inputs, targets) in enumerate(testloader):
+            inputs, targets = inputs.to(device), targets.to(device)
+            outputs = net(inputs)
+            _, predicted = outputs.max(1)
+            total += targets.size(0)
+            correct += predicted.eq(targets).sum().item()
+            print(batch_idx, len(testloader))
+    print ('Calculated accuracy: %f%%' % (100. * correct / total))
diff --git a/utils.py b/utils.py
index 4c9b3f90c..3b0e72ea4 100644
--- a/utils.py
+++ b/utils.py
@@ -5,6 +5,7 @@
 '''
 import os
 import sys
+import platform
 import time
 import math
@@ -42,8 +43,28 @@ def init_params(net):
                 init.constant(m.bias, 0)
 
 
-_, term_width = os.popen('stty size', 'r').read().split()
-term_width = int(term_width)
+sysstr = platform.system()
(sysstr == "Windows"): + from ctypes import windll, create_string_buffer + # stdin handle is -10 + # stdout handle is -11 + # stderr handle is -12 + h = windll.kernel32.GetStdHandle(-12) + csbi = create_string_buffer(22) + res = windll.kernel32.GetConsoleScreenBufferInfo(h, csbi) + if res: + import struct + (bufx, bufy, curx, cury, wattr, left, top, right, bottom, maxx, maxy) = struct.unpack("hhhhHhhhhhh", csbi.raw) + term_width = right - left + 1 + else: + term_width = 100 # can't determine actual size - return default values +else: + try: + _, term_width = os.popen('stty size', 'r').read().split() + term_width = int(term_width) + except: + print("Cannot load stty information") + term_width = 100 TOTAL_BAR_LENGTH = 65. last_time = time.time()