  • Multi-layer fully connected neural network for MNIST handwritten digit classification

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms
    batch_size = 64
    learning_rate = 1e-2
    num_epoches = 20
    data_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
    # transforms.Compose() chains the preprocessing steps together
    # transforms.ToTensor() converts the data to a Tensor and scales pixel values into [0, 1]
    # transforms.Normalize() standardizes the data (like standardizing a normal distribution);
    # the first argument is the mean, the second is the standard deviation
    # for a three-channel image, use transforms.Normalize([a, b, c], [d, e, f])
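    # Worked example (illustrative pixel value, not taken from the dataset): a raw pixel of 128
    # becomes 128/255 ≈ 0.502 after ToTensor, and Normalize then maps it to
    # (0.502 - 0.5) / 0.5 ≈ 0.004, so every pixel ends up in the range [-1, 1]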
    train_dataset = datasets.MNIST(root = './mnist_data', train = True, transform = data_tf, download = True)   # load the dataset with torchvision.datasets, applying the preprocessing
    test_dataset  = datasets.MNIST(root = './mnist_data', train = False, transform = data_tf)
    train_loader  = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)    # DataLoader wraps the dataset in a batch iterator
    test_loader   = DataLoader(test_dataset,  batch_size = batch_size, shuffle = False)
    class Batch_Net(nn.Module):
        def __init__(self, inputdim, hidden1, hidden2, outputdim):
            super(Batch_Net, self).__init__()
            self.layer1 = nn.Sequential(nn.Linear(inputdim, hidden1), nn.BatchNorm1d(hidden1), nn.ReLU(True))   # linear -> batch norm -> ReLU
            self.layer2 = nn.Sequential(nn.Linear(hidden1, hidden2), nn.BatchNorm1d(hidden2), nn.ReLU(True))
            self.layer3 = nn.Sequential(nn.Linear(hidden2, outputdim))   # output layer: raw logits, no activation
        
        def forward(self, x):
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)   
            return x
    model = Batch_Net(28*28, 300, 100, 10)
    print(model)   # display the architecture
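
    As a quick sanity check, we can push a dummy batch through the untrained network and confirm the output shape (a minimal sketch; the batch size of 4 is arbitrary):

    dummy = torch.randn(4, 28*28)   # a fake batch of four flattened "images"
    model.eval()                    # BatchNorm uses its running statistics in eval mode
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)                # torch.Size([4, 10]) -- one logit per digit class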

    Define the loss function and optimizer

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr = learning_rate)
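
    Note that nn.CrossEntropyLoss combines LogSoftmax and NLLLoss internally, which is why layer3 outputs raw logits with no final activation. A minimal worked example with made-up scores:

    logits = torch.tensor([[2.0, 0.5, 0.1]])       # raw scores for three classes
    target = torch.tensor([0])                     # the true class index
    print(nn.CrossEntropyLoss()(logits, target))   # ≈ 0.3168, i.e. -log(softmax(logits)[0, 0])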

    Train the model

    for epoch in range(num_epoches):
        train_loss = 0
        train_acc = 0
        model.train()   # switches BatchNorm (and Dropout) to training-mode behavior -- important!
        for img, label in train_loader:
            img = img.view(img.size(0), -1)   # flatten each 28x28 image into a 784-dim vector
            # forward pass
            out = model(img)
            loss = criterion(out, label)
            # backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # accumulate the running loss
            train_loss += loss.item()
            # compute the batch classification accuracy
            _, pred = out.max(1)
            num_correct = (pred == label).sum().item()
            acc = num_correct / img.shape[0]
            train_acc += acc
    
        print('epoch:{},train_loss:{:.6f},acc:{:.6f}'.format(epoch+1, train_loss/len(train_loader), train_acc/len(train_loader)))
    epoch:1,train_loss:0.002079,acc:0.999767
    ......  
    epoch:19,train_loss:0.001532,acc:0.999917
    epoch:20,train_loss:0.001670,acc:0.999850
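
    After training, it is common to persist the learned weights. A minimal sketch (the file name mnist_fc.pth is just a placeholder, not from the original post):

    torch.save(model.state_dict(), 'mnist_fc.pth')   # save only the parameter tensors
    # later: model.load_state_dict(torch.load('mnist_fc.pth'))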

    Evaluate on the test set

    model.eval()   # fixes BatchNorm and Dropout to evaluation behavior
    eval_loss = 0
    eval_acc = 0
    with torch.no_grad():   # skip building the autograd graph: the test set needs no backprop
        for img, label in test_loader:
            img = img.view(img.size(0), -1)
            out = model(img)
            loss = criterion(out, label)
            eval_loss += loss.item()
            _, pred = torch.max(out, 1)
            num_correct = (pred == label).sum().item()
            eval_acc += num_correct / label.shape[0]

    print('Test Loss:{:.6f}, Acc:{:.6f}'.format(eval_loss/len(test_loader), eval_acc/len(test_loader)))
        
    Test Loss:0.062413, Acc:0.981091
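
    Once the model is in eval mode, classifying a single test image looks like this (a minimal sketch; index 0 is an arbitrary choice):

    img, label = test_dataset[0]          # one normalized image tensor and its integer label
    with torch.no_grad():
        out = model(img.view(1, -1))      # add a batch dimension and flatten to 784 values
        pred = out.argmax(dim=1).item()
    print('predicted: {}, actual: {}'.format(pred, label))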

     

  • Original article: https://www.cnblogs.com/zgqcn/p/10850507.html