zoukankan      html  css  js  c++  java
  • 第一个深度学习网络(别人的)

    import numpy as np
    import torch
    from torchvision.datasets import mnist
    import  torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    import torch.nn.functional as F
    import torch.optim as optim
    from torch import  nn
    from matplotlib import pyplot as plt
    
    # Hyperparameters
    train_batch_size, test_batch_size = 64, 128  # mini-batch sizes for the two loaders
    num_epoches = 3                              # number of passes over the training set
    # Both names hold the same step size; the optimizer below is given `lr`.
    learning_rate = 0.01
    lr = 0.01
    momentum = 0.5                               # SGD momentum factor
    
    # Preprocessing steps, applied in the order they appear inside Compose:
    # ToTensor first, then Normalize with mean 0.5 and std 0.5.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ]
    )
    # MNIST datasets stored under ./data, both using the preprocessing above.
    # Only the training split asks for a download.
    train_dataset = mnist.MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = mnist.MNIST(root='./data', train=False, transform=transform)
    # DataLoaders are iterables yielding (images, labels) batches; only the
    # training loader shuffles.
    train_loader = DataLoader(dataset=train_dataset, batch_size=train_batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=test_batch_size, shuffle=False)
    
    # Take the first batch from the test loader and draw six of its images
    # on a 2x3 grid of grayscale subplots.
    examples = enumerate(test_loader)
    batch_idx, (example_data, example_targets) = next(examples)
    fig = plt.figure()
    for i in range(6):
        image = example_data[i][0]  # index 0 on the second axis of sample i
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(image, cmap='gray', interpolation='none')
    # Display is left disabled, as in the original:
    # plt.show()
    # Network definition
    class Net(nn.Module):
        """Fully connected classifier with two hidden layers.

        Each hidden layer is a Linear followed by BatchNorm1d, with ReLU
        applied in ``forward``. The output layer is a plain Linear that
        returns ``out_dim`` raw scores per sample (no softmax).

        Args:
            in_dim: Number of input features per sample.
            n_hidden_1: Width of the first hidden layer.
            n_hidden_2: Width of the second hidden layer.
            out_dim: Number of output scores per sample.
        """

        def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
            super().__init__()
            # nn.Sequential chains the listed modules into one callable layer.
            self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1))
            self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2))
            # The output layer needs its own attribute: assigning it to
            # ``layer1`` would replace the first hidden layer, leaving a model
            # that only runs when in_dim == n_hidden_2 and n_hidden_1 == out_dim.
            self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))

        def forward(self, x):
            """Return the output scores for a batch ``x`` with ``in_dim`` features per row."""
            x = F.relu(self.layer1(x))
            x = F.relu(self.layer2(x))
            # No activation on the output layer.
            x = self.layer3(x)
            return x
    # Instantiate the network.
    # Run on the first CUDA device when one is available, otherwise on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    model = Net(28 * 28, 10, 784, 10)
    model.to(device)
    # Loss function and optimizer: cross-entropy, and SGD with the
    # configured learning rate and momentum.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    
    # Training and evaluation.
    # Per-epoch histories: mean batch loss and mean batch accuracy for the
    # training set and the test set.
    losses = []
    acces = []
    eval_losses = []
    eval_acces = []
    for epoch in range(num_epoches):
        train_loss = 0
        train_acc = 0
        model.train()
        # Scale the learning rate by 0.1 every 5 epochs. Epoch 0 is skipped so
        # that training starts at the configured rate instead of a tenth of it.
        if epoch > 0 and epoch % 5 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
        for img, label in train_loader:
            img = img.to(device)
            label = label.to(device)
            # Flatten each sample to one row of features.
            img = img.view(img.size(0), -1)
            # Forward pass
            out = model(img)
            loss = criterion(out, label)
            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate the batch loss
            train_loss += loss.item()
            # Batch accuracy: fraction of samples whose highest score matches the label
            _, pred = out.max(1)
            num_correct = (pred == label).sum().item()
            acc = num_correct / img.shape[0]
            train_acc += acc

        epoch_train_loss = train_loss / len(train_loader)
        epoch_train_acc = train_acc / len(train_loader)
        losses.append(epoch_train_loss)
        acces.append(epoch_train_acc)

        eval_loss = 0
        eval_acc = 0
        model.eval()
        # No gradients are needed for evaluation; skipping autograd tracking
        # avoids building a graph for every test batch.
        with torch.no_grad():
            for img, label in test_loader:
                img = img.to(device)
                label = label.to(device)
                img = img.view(img.size(0), -1)
                out = model(img)
                loss = criterion(out, label)
                # Accumulate the batch loss
                eval_loss += loss.item()
                # Accumulate the batch accuracy
                _, pred = out.max(1)
                num_correct = (pred == label).sum().item()
                acc = num_correct / img.shape[0]
                eval_acc += acc

        epoch_eval_loss = eval_loss / len(test_loader)
        epoch_eval_acc = eval_acc / len(test_loader)
        eval_losses.append(epoch_eval_loss)
        eval_acces.append(epoch_eval_acc)
        print('epoch:{},Train Loss:{:.4f},Train Acc:{:.4f},Test Loss:{:.4f},Test Acc:{:.4f}'
              .format(epoch, epoch_train_loss, epoch_train_acc, epoch_eval_loss, epoch_eval_acc))

    体验自己开发的深度学习的乐趣

  • 相关阅读:
    jenkins 邮件配置
    jenkins+git学习笔记
    用户定义的变量+HTTP Cookie 管理器组合实现接口关联+问题处理
    jmeter参数化实现之CSV Data Set Config
    Jmeter学习笔记
    除法应用遇到的问题-类型及小数点
    python2输出中文乱码问题
    python常见函数及方法
    数据库的基本操作
    使用eclipse搭建maven项目
  • 原文地址:https://www.cnblogs.com/gao109214/p/13858122.html
Copyright © 2011-2022 走看看