zoukankan      html  css  js  c++  java
  • pytorch-第二章神经网络实战及回归任务-mnist数据集分类

    第一步: 进行数据的读取,使用request进行网址内容的抓取 

    from pathlib import Path
    import requests
    
    DATA_PATH = Path('data')
    Path = DATA_PATH / "mnist"
    
    Path.mkdir(parents=True, exist_ok=True)
    
    URL = "http://deeplearning.net/data/mnist/"
    FILENAME = "mnist.pkl.gz"
    
    if not (Path / FILENAME).exists():
        content = requests.get(URL + FILENAME).content
        (Path / FILENAME).open("wb").write(content)
    
    
    import pickle
    import gzip
    
    
    with gzip.open((Path / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
    
    
    from matplotlib import pyplot as plt
    import numpy as np
    
    plt.imshow(x_train[0].reshape(28, 28), cmap='gray')
    plt.show()

     第二步: 使用TensorDataset 和 Dataloader 来构造队列数据集,损失值使用的F.cross_entropy, 最后再进行网络的优化和效果显示 

    import torch
    from torch import nn
    
    
    x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid))
    
    import torch.nn.functional as F
    
    loss_func = F.cross_entropy
    
    def model(xb):
        return xb.mm(weights) + biases
    
    step = 20
    bs = 64
    xb = x_train[0:bs]
    yb = y_train[0:bs]
    
    weights = torch.randn([784, 10], dtype=torch.float, requires_grad=True)
    biases = torch.randn(10, dtype=torch.float, requires_grad=True)
    
    
    # 定义好模型
    class Mnist_nn(nn.Module):
        def __init__(self):
            super().__init__()
            self.hidden1 = nn.Linear(784, 128)
            self.hidden2 = nn.Linear(128, 256)
            self.out = nn.Linear(256, 10)
        def forward(self, x):
            x = F.relu(self.hidden1(x))
            x = F.relu(self.hidden2(x))
            x = self.out(x)
            return x
    
    def loss_batch(model, xx, yy, loss_func, opt=None):
        loss = loss_func(model(xx), yy)
        if opt != None:
            loss.backward()
            opt.step()
            opt.zero_grad()
        return loss.item(), len(xx)
    
    def fit(step, model, loss_func, train_dl, test_dl, opt):
        for i in range(step):
            model.train()
            for xx, yy in train_dl:
                loss, num = loss_batch(model, xx, yy, loss_func, opt)
            model.eval()
            with torch.no_grad():
                loss, num = zip(*[loss_batch(model, xx, yy, loss_func) for xx, yy in test_dl])
            print("验证集的损失值是:", np.sum(np.multiply(loss, num)) / np.sum(num))
        # 进行预测结果的显示
        predict = np.argmax(model(torch.tensor(x_valid[0], dtype=torch.float)).data.numpy(), axis=0)
    
        plt.imshow(x_valid[0].reshape(28, 28), cmap='gray')
        plt.show()
        print("当前的预测结果是", predict)
    
    def get_model():
        model = Mnist_nn() #实例化模型
        return model, torch.optim.SGD(model.parameters(), lr=0.001)
    
    
    # 数据构造
    from torch.utils.data import TensorDataset
    from torch.utils.data import DataLoader
    
    train_ds = TensorDataset(x_train, y_train)
    train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
    
    test_ds = TensorDataset(x_valid, y_valid)
    test_dl = DataLoader(test_ds, batch_size=bs, shuffle=True)
    
    
    model_mnist, optim = get_model()
    fit(step, model_mnist, loss_func, train_dl, test_dl, optim)

  • 相关阅读:
    Entity Framework 6 Recipes 2nd Edition(9-4)译->Web API 的客户端实现修改跟踪
    Entity Framework 6 Recipes 2nd Edition(9-3)译->找出Web API中发生了什么变化
    Entity Framework 6 Recipes 2nd Edition(9-2)译->用WCF更新单独分离的实体
    Entity Framework 6 Recipes 2nd Edition(9-1)译->用Web Api更新单独分离的实体
    jar包和war包的介绍和区别
    软件设计-高内聚耦合(转)
    Android中的dp,px以及wrap_content的实际展示效果
    Eclipse编辑器样式修改
    对TextView设置drawable,用setCompoundDrawables方法实现
    Android调用本机应用市场,实现应用评分功能
  • 原文地址:https://www.cnblogs.com/my-love-is-python/p/12667771.html
Copyright © 2011-2022 走看看