zoukankan      html  css  js  c++  java
  • tensorflow(十七):数据的加载:map()、shuffle()、tf.data.Dataset.from_tensor_slices()

    一、数据集简介

     

    二、MNIST数据集介绍

     三、CIFAR 10/100数据集介绍

     

     四、tf.data.Dataset.from_tensor_slices()

     五、shuffle()随机打散

     六、map()数据预处理

     

     

     七、实战

    import tensorflow as tf
    import tensorflow.keras as keras
    import os
    
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
    def prepare_mnist_features_and_labels(x,y):
        x = tf.cast(x, tf.float32) / 255.0
        y = tf.cast(y, tf.int64)
        return x,y
    
    def mnist_dataset():
        (x,y), (x_test,y_test) = keras.datasets.fashion_mnist.load_data() #numpy中的格式
    
        y = tf.one_hot(y, depth=10)                     #[10k] ==> [10k,10]的tensor
        y_test = tf.one_hot(y_test, depth=10)
    
        ds = tf.data.Dataset.from_tensor_slices((x,y))
        ds = ds.map(prepare_mnist_features_and_labels)  #数据预处理,注意:tf.map中传进的参数
        ds = ds.shuffle(60000).batch(100)               #随机打散,读取一个batch的样本
    
        ds_val = tf.data.Dataset.from_tensor_slices((x_test,y_test))
        ds_val = ds_val.map(prepare_mnist_features_and_labels)
        ds_val = ds_val.shuffle(10000).batch(100)
        return ds, ds_val
    
    
    def main():
        ds, ds_val = mnist_dataset()
    
        print("训练集信息如下:")
        iteration_ds = iter(ds)
        iter_ds = next(iteration_ds)
        print(iter_ds[0].shape, iter_ds[1].shape)
    
        print("测试集信息如下:")
        iteration_ds_val = iter(ds_val)
        iter_ds_val = next(iteration_ds_val)
        print(iter_ds_val[0].shape, iter_ds_val[1].shape)
    
    if __name__ == '__main__':
        main()

     

  • 相关阅读:
    shell getopt getopts获取参数
    apache+svn+ladp认证
    SVN 迁移项目分支
    iptables 优先级
    很实用的一篇HTTP状态码
    套路还在——矩阵计算估值
    CU上看到的一个简单的算法帖子
    linux下服务端实现公网数据转发
    c++接口实现与分离(转载)
    c++继承概念
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/14612327.html
Copyright © 2011-2022 走看看