zoukankan      html  css  js  c++  java
  • 通过迭代器获取数据

    %pylab inline
    from keras.datasets import mnist
    import mxnet as mx
    from mxnet import nd
    from mxnet import autograd 
    import random
    from mxnet import gluon
    
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    num_examples = x_train.shape[0]
    num_inputs = x_train.shape[1] * x_train.shape[2]
    batch_size = 64
    

    1. 自定义数据迭代器

    def data_iter1(X, Y, batch_size):
        num_samples = X.shape[0]
        idx = list(range(num_samples))
        random.shuffle(idx)
        
        X = nd.array(X)
        Y = nd.array(Y)
        for i in range(0, num_examples, batch_size):
            j = nd.array(idx[i: min(i + batch_size, num_examples)])
            yield nd.take(X, j), nd.take(Y, j)
    

    2. Gluon 迭代器

    dataset = gluon.data.ArrayDataset(x_train, y_train)
    data_iter = gluon.data.DataLoader(dataset, batch_size, shuffle=True)
    

    3. 从迭代器中获取数据

    for data, label in data_iter:
        print(data.shape, label.shape)
        break
    
    (64, 28, 28) (64,)
    
    for data, label in data_iter1(x_train, y_train, batch_size):
        print(data.shape, label.shape)
        break
    
    (64, 28, 28) (64,)
    

    更多精彩见:使用 迭代器 获取 Cifar 等常用数据集

  • 相关阅读:
    并查集
    结构体字节对齐
    Dijkstra算法(单源最短路径)
    图的遍历
    二叉树的非递归遍历
    浅谈C语言中的联合体
    二叉排序(查找)树
    KMP算法
    C语言文件操作解析(四)
    Trie树
  • 原文地址:https://www.cnblogs.com/q735613050/p/8367173.html
Copyright © 2011-2022 走看看