zoukankan      html  css  js  c++  java
  • 利用torch.nn实现前馈神经网络解决 多分类 任务

    1 导入包

    from torchvision import datasets,transforms
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    import torchvision
    from torch import nn
    import numpy as np
    import torch

    2 加载数据集

    transformation =transforms.Compose([
        transforms.ToTensor()  #转换到Tensor,并且转换为0-1之间,将channel 放到第一个纬度
    ])
    train_ds = datasets.MNIST('data/',train = True,transform = transformation,download = True)
    test_ds = datasets.MNIST('data/',train = False,transform = transformation,download = True)
    # len(train_ds)
    # len(test_ds)

    3 加载loader

    train_loader = DataLoader(train_ds,batch_size =64 ,shuffle = True,num_workers = 16)
    test_loader = DataLoader(test_ds,batch_size =256 ,shuffle = False,num_workers = 16)

    4 创建模型

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(28*28,128)
            self.linear2 = nn.Linear(128,64)
            self.linear3 = nn.Linear(64,10)
        def forward(self,input):
            x = input.view(-1,28*28)
            x = nn.functional.relu(self.linear1(x))
            x = nn.functional.relu(self.linear2(x))
            y = self.linear3(x)
            return y

    5 模型、损失函数、优化器设置

    model = Model()
    loss_fn  = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=0.001)

    6 精度、测试集精度和损失

    def accuracy(y_pred,y_true):
        y_pred = torch.argmax(y_pred,dim=1)
        acc = (y_pred==y_true).float().mean()
        return acc
    #测试集
    def evaluate_testset(data_loader,model):
        acc_sum,loss_sum,total_example = 0.0,0.0,0
        for x,y in data_loader:
            y_hat = model(x)
            acc_sum += (y_hat.argmax(dim=1)==y).sum().item()
            loss = loss_fn(y_hat,y) 
            loss_sum += loss.item()
            total_example+=y.shape[0]
        return acc_sum/total_example,loss_sum

    7 定义训练函数

    #定义模型训练函数
    def train(model,train_loader,test_loader,loss,num_epochs,batch_size,params=None,lr=None,optimizer=None):
        train_ls = []
        test_ls = []
        for epoch in range(num_epochs): # 训练模型一共需要num_epochs个迭代周期
            train_loss_sum, train_acc_num,total_examples = 0.0,0.0,0
            for x, y in train_loader: # x和y分别是小批量样本的特征和标签
                y_pred = model(x)
                loss = loss_fn(y_pred, y)  #计算损失
                optimizer.zero_grad() # 梯度清零
                loss.backward()  # 反向传播
                optimizer.step() #梯度更新
                total_examples += y.shape[0]
                train_loss_sum += loss.item()
                train_acc_num += (y_pred.argmax(dim=1)==y).sum().item()
            train_ls.append(train_loss_sum)
            test_acc,test_loss = evaluate_testset(test_loader,model)
            test_ls.append(test_loss)
            print('epoch %d, train_loss %.6f,test_loss %f,train_acc %.6f,test_acc %.6f'%(epoch+1, train_ls[epoch],test_ls[epoch],train_acc_num/total_examples,test_acc))
        return 

    8 超参数设置

    num_epoch  = 20
    batch_size = 64

    9 模型训练

    train(model,train_loader,test_loader,loss_fn,num_epoch,batch_size,params=model.parameters,lr=0.001,optimizer=optimizer)

    因上求缘,果上努力~~~~ 作者:希望每天涨粉,转载请注明原文链接:https://www.cnblogs.com/BlairGrowing/p/15505314.html

  • 相关阅读:
    Eclipse复制或修改项目后,把项目部署后发现还是原来的项目名称
    eclipse设置新建jsp文件默认字符编码为utf-8
    mysql数据库无法插入中文字符
    Dos中查看mysql数据时 中文乱码
    spring 解决中文乱码问题
    mysql再次安装问题
    The import javax.servlet cannot be resolved
    eclipse快捷键补全
    eclipse自动补全
    hibernate运行常见错误
  • 原文地址:https://www.cnblogs.com/BlairGrowing/p/15505314.html
Copyright © 2011-2022 走看看