zoukankan      html  css  js  c++  java
  • pytorch的Dataloader的shuffle

    https://blog.csdn.net/qq_20200047/article/details/105671374

    1.简单测import sys

    import sys
    import torch
    import random
    import argparse
    import numpy as np
    import pandas as pd
    import torch.nn as nn
    from torch.nn import functional as F
    from torch.optim import lr_scheduler
    from torchvision import datasets, transforms
    from torch.utils.data import TensorDataset, DataLoader, Dataset
     
    class DealDataset(Dataset):
        def __init__(self):
            xy = np.random.randn(4,3)
            print(xy)
            print("
    ")
            self.x_data = torch.from_numpy(xy[:, 0:-1])
            self.y_data = torch.from_numpy(xy[:, [-1]])
            self.len = xy.shape[0]
        
        def __getitem__(self, index):
            return self.x_data[index], self.y_data[index]
     
        def __len__(self):
            return self.len
       
    dealDataset = DealDataset()
     
    train_loader2 = DataLoader(dataset=dealDataset,
                              batch_size=2,
                              shuffle=True)
    for j in range(5):
        for i, data in enumerate(train_loader2):
            inputs, labels = data
    
            #inputs, labels = Variable(inputs), Variable(labels)
            print(inputs)
        print("
    ")
            #print("epoch:", epoch, "的第" , i, "个inputs", inputs.data.size(), "labels", labels.data.size())

    输出:

    [[ 1.35870858 -0.74676435 -0.4181123 ]
     [ 0.14165115 -1.55553785 -2.03821185]
     [ 0.46154706  1.36100343 -0.13686081]
     [ 0.59683626  1.60361944 -0.90266193]]
    
    
    tensor([[ 0.4615,  1.3610],
            [ 0.1417, -1.5555]], dtype=torch.float64)
    tensor([[ 0.5968,  1.6036],
            [ 1.3587, -0.7468]], dtype=torch.float64)
    
    
    tensor([[ 1.3587, -0.7468],
            [ 0.4615,  1.3610]], dtype=torch.float64)
    tensor([[ 0.5968,  1.6036],
            [ 0.1417, -1.5555]], dtype=torch.float64)
    
    
    tensor([[ 0.1417, -1.5555],
            [ 1.3587, -0.7468]], dtype=torch.float64)
    tensor([[0.5968, 1.6036],
            [0.4615, 1.3610]], dtype=torch.float64)
    
    
    tensor([[ 0.5968,  1.6036],
            [ 1.3587, -0.7468]], dtype=torch.float64)
    tensor([[ 0.4615,  1.3610],
            [ 0.1417, -1.5555]], dtype=torch.float64)
    
    
    tensor([[ 0.5968,  1.6036],
            [ 0.1417, -1.5555]], dtype=torch.float64)
    tensor([[ 1.3587, -0.7468],
            [ 0.4615,  1.3610]], dtype=torch.float64)

    说明每次调用dataloader都是重新打乱,而不是在定义的时候只打乱一次。

  • 相关阅读:
    Python3 实现一个简单的TCP 客户端
    Mac 下 安装 和 使用 Go 框架 Beego
    Go 操作文件及文件夹 os.Mkdir及os.MkdirAll两者的区别
    Go gin 之 Air 实现实时加载
    Mac os 配置常用alias
    Mac 下 MAMP配置虚拟主机
    Thinkphp5 项目部署至linux服务器报403错误
    Linux 安装最新版 node.js 之坑
    Mac item2如何使用rz sz 上传下载命令
    Mac 使用 iTerm2 快捷登录远程服务器
  • 原文地址:https://www.cnblogs.com/BlueBlueSea/p/13853796.html
Copyright © 2011-2022 走看看