zoukankan      html  css  js  c++  java
  • Python 加载mnist、cifar数据


    import
    tensorflow.examples.tutorials.mnist.input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    1、加载mnist数据

    执行完成后,会在当前目录下新建一个文件夹MNIST_data, 下载的数据将放入这个文件夹内。下载的四个文件为:

    下载下来的数据集被分三个子集:5.5W行的训练数据集(mnist.train),5千行的验证数据集(mnist.validation)和1W行的测试数据集(mnist.test)。因为每张图片为28x28的黑白图片,所以每行为784维的向量。

    print (mnist.train.images.shape)
    print (mnist.train.labels.shape)
    print (mnist.validation.images.shape)
    print (mnist.validation.labels.shape)
    print (mnist.test.images.shape)
    print (mnist.test.labels.shape)

    (55000, 784)
    (55000, 10)
    (5000, 784)
    (5000, 10)
    (10000, 784)


    (10000, 10)

    在训练过程中可以按批次获取

    from tensorflow.examples.tutorials.mnist import input_data #导入手写数字数据集
    
    mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True) #mnist是手写数字数据集
    X_mb, _ = mnist.train.next_batch(128)
    print(X_mb.shape)

    Extracting ../../MNIST_data rain-images-idx3-ubyte.gz
    Extracting ../../MNIST_data rain-labels-idx1-ubyte.gz
    Extracting ../../MNIST_data 10k-images-idx3-ubyte.gz
    Extracting ../../MNIST_data 10k-labels-idx1-ubyte.gz
    (128, 784)

     2、加载cifar数据

    import torch
    import torchvision.datasets as dsets
    import torchvision.transforms as transforms
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    def load_data_CIFAR10():
        train_dataset = dsets.CIFAR10(root='./data/', train=True,download=True, transform=transform)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
        return train_loader
    train_loader = load_data_CIFAR10()

    Using downloaded and verified file: ./data/cifar-10-python.tar.gz
    Extracting ./data/cifar-10-python.tar.gz to ./data/

    cifar-10  训练集和测试集分别有50000和10000张图片,RGB3通道,尺寸32×32, 

    一个样本由3037个字节组成,其中第一个字节是label,剩余3036(32*32*3)个字节是image,每个文件由连续的10000个样本组成,打开文件,发现是一堆二进制数据

    https://www.cnblogs.com/denny402/p/5852689.html

  • 相关阅读:
    【蓝桥杯/算法训练】Sticks 剪枝算法 (附胜利大逃亡)
    【蓝桥杯/基础练习】回文数、特殊的回文数
    【蓝桥杯/基础练习】十六进制转八进制
    交叉验证
    第一次写博客---交叉验证
    实验五
    汇编语言第二章
    实验四
    实验三
    实验二
  • 原文地址:https://www.cnblogs.com/gaona666/p/12349751.html
Copyright © 2011-2022 走看看