#!/usr/bin/env python
"""Train LeNet-5 on MNIST, evaluate it after every epoch, and export it to ONNX.

When run as a script this trains for 15 epochs, then writes ``lenet.onnx`` and
validates it with ``onnx.checker.check_model``.
"""
from lenet import LeNet5
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import onnx


def train(epoch):
    """Run one training epoch over ``data_train_loader``.

    Reads the module-level ``net``, ``criterion``, ``optimizer`` and
    ``data_train_loader`` and prints the loss of every 10th batch.

    Args:
        epoch: Epoch number; used only in the progress message.
    """
    net.train()
    for i, (images, labels) in enumerate(data_train_loader):
        optimizer.zero_grad()
        output = net(images)
        loss = criterion(output, labels)
        if i % 10 == 0:
            print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.item()))
        loss.backward()
        optimizer.step()


def test():
    """Evaluate ``net`` on ``data_test_loader`` and print average loss and accuracy.

    The reported loss is a per-sample average over the whole test set: each
    batch's mean loss is weighted by its batch size before dividing by the
    dataset length.
    """
    net.eval()
    total_correct = 0
    total_loss = 0.0
    # Evaluation needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        for images, labels in data_test_loader:
            output = net(images)
            # criterion is nn.CrossEntropyLoss() with the default reduction='mean',
            # so convert the batch mean back to a batch sum before averaging.
            total_loss += criterion(output, labels).item() * labels.size(0)
            pred = output.argmax(1)
            total_correct += pred.eq(labels.view_as(pred)).sum().item()
    avg_loss = total_loss / len(data_test)
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss, float(total_correct) / len(data_test)))


if __name__ == '__main__':
    # The guard is required because the DataLoaders use worker processes
    # (num_workers=8). With the 'spawn' start method (Windows, macOS) each worker
    # re-imports this module, and unguarded training code would run again there.
    mnist_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()])
    data_train = MNIST('./data',
                       download=True,
                       transform=mnist_transform)
    data_test = MNIST('./data',
                      train=False,
                      download=True,
                      transform=mnist_transform)
    data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
    data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)

    net = LeNet5()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=2e-3)

    # NOTE: not read anywhere in this script; kept so the names still exist.
    cur_batch_win = None
    cur_batch_win_opts = {
        'title': 'Epoch Loss Trace',
        'xlabel': 'Batch Number',
        'ylabel': 'Loss',
        'width': 1200,
        'height': 600,
    }

    for epoch in range(1, 16):
        train(epoch)
        test()

    # test() has just put the model in eval mode, which is the mode exported here.
    dummy_input = torch.randn(1, 1, 32, 32, requires_grad=True)
    torch.onnx.export(net, dummy_input, 'lenet.onnx')

    onnx_model = onnx.load('lenet.onnx')
    onnx.checker.check_model(onnx_model)