  • 08 Dataset and Dataloader

    Mini-Batch

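    Why the Batch_Size parameter is needed

    Code implementation

    The model is still the multi-dimensional-input model from the previous lecture; the loss function and optimizer are unchanged as well.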

    from abc import ABC
    
    import numpy as np
    import torch
    from torch.utils.data import Dataset  # Dataset is an abstract class: subclass it rather than instantiating it directly
    from torch.utils.data import DataLoader
    
    
    class DiabetesDataset(Dataset):
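        """
        Inherits from the abstract class Dataset
        """
        def __init__(self, filepath):
            """
            :param filepath: path to the data file
            """
            xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
            self.len = xy.shape[0]  # For an N-row, M-column dataset, shape[0] is the size of dimension 0, i.e. the row count N; stored for __len__
            self.x_data = torch.from_numpy(xy[:, :-1])  # every column except the last is a feature
            self.y_data = torch.from_numpy(xy[:, [-1]])  # indexing with the list [-1] keeps the labels as an N x 1 matrix
    
        def __getitem__(self, index):
            """
            :param index: index of the sample
            :return: the sample stored at that index
            """
            return self.x_data[index], self.y_data[index]  # return the x and y entries at this index; the result is a tuple
    
        def __len__(self):
            """
            :return: number of samples in the dataset
            """
            return self.len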
    
    
    dataset = DiabetesDataset('diabetes.csv.gz')
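    train_loader = DataLoader(dataset=dataset,  # the dataset to draw samples from
                              batch_size=2,     # number of samples in each mini-batch
                              shuffle=True,     # whether to shuffle the data; True shuffles it every epoch
                              num_workers=2)    # number of worker subprocesses used to load data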
    
    
    class Model(torch.nn.Module, ABC):
        def __init__(self):
            super(Model, self).__init__()
            self.linear1 = torch.nn.Linear(8, 6)  # dimensions: 8 -> 6 -> 4 -> 1
            self.linear2 = torch.nn.Linear(6, 4)
            self.linear3 = torch.nn.Linear(4, 1)
            self.sigmoid = torch.nn.Sigmoid()
    
        def forward(self, x):
            x = self.sigmoid(self.linear1(x))
            x = self.sigmoid(self.linear2(x))
            x = self.sigmoid(self.linear3(x))
            return x
    
    
    model = Model()
    criterion = torch.nn.BCELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    
    
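    if __name__ == '__main__':  # With num_workers > 0 the loader starts worker processes, so the loop must sit under this guard; otherwise a RuntimeError is raised on platforms that spawn processes (e.g. Windows)
        for epoch in range(100):
            
            for i, data in enumerate(train_loader, 0):
                # enumerate gives the index i of the current iteration; each (x, y) batch from train_loader is bound to data.
                # The DataLoader uses batch_size=2 and the training set holds 759 samples,
                # so every epoch runs 380 iterations, i.e. i goes from 0 to 379.
                inputs, labels = data
                y_pred = model(inputs)
                loss = criterion(y_pred, labels)
                print(epoch, i, loss.item())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()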
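
    The iteration count quoted in the training-loop comment can be checked with a little arithmetic. The sketch below is only an illustration: the helper name batches_per_epoch is hypothetical, and the numbers 759 and 2 are the ones quoted in the comment. Its drop_last argument mirrors the DataLoader option of the same name, which discards an incomplete final batch.

```python
import math


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of mini-batches produced by one pass over the data."""
    if drop_last:
        # An incomplete final batch is discarded
        return num_samples // batch_size
    # An incomplete final batch still counts as one iteration
    return math.ceil(num_samples / batch_size)


num_samples, batch_size = 759, 2
n_batches = batches_per_epoch(num_samples, batch_size)
# Size of the final batch when it is kept
last_batch_size = num_samples - (n_batches - 1) * batch_size
print(n_batches, last_batch_size)
```

    With 759 samples and a batch size of 2, the final batch holds a single sample, which is why the index i stops at 379 rather than 378.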
    

    [Figures: training output at epoch 0 and at epoch 99]
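
    Printing the loss of every mini-batch makes it hard to compare one epoch with another. The following is a minimal sketch of one alternative, under stated assumptions: random synthetic data stands in for diabetes.csv.gz, a single linear layer stands in for the model above, and epoch_mean_loss is a hypothetical helper name. Because BCELoss(reduction='sum') returns the summed loss of a batch, adding those sums over an epoch and dividing by the dataset size gives the mean per-sample loss observed during that epoch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in data (assumption): 10 samples, 8 features, binary labels
x = torch.randn(10, 8)
y = torch.randint(0, 2, (10, 1)).float()
dataset = TensorDataset(x, y)
# num_workers=0 loads data in the main process, so no __main__ guard is needed here
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

# Simplified stand-in model (assumption): one linear layer followed by a sigmoid
model = torch.nn.Sequential(torch.nn.Linear(8, 1), torch.nn.Sigmoid())
criterion = torch.nn.BCELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


def epoch_mean_loss(model, loader, criterion, optimizer):
    """Train for one epoch and return the mean per-sample loss seen during it."""
    total = 0.0
    for inputs, labels in loader:
        loss = criterion(model(inputs), labels)  # summed loss of this batch
        total += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return total / len(loader.dataset)


history = [epoch_mean_loss(model, loader, criterion, optimizer) for _ in range(3)]
print(len(loader), history)
```

    One number per epoch is then enough to see whether training is making progress.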


    Reference

    https://www.bilibili.com/video/BV1Y7411d7Ys?p=8

  • Original article: https://www.cnblogs.com/vict0r/p/13618398.html