zoukankan      html  css  js  c++  java
  • tensorflow学习014——tf.data运用实例

    3.2tf.data运用实例

    使用tf.data作为输入,改写之前写过的MNIST代码

    点击查看代码
    import tensorflow as tf
    #下载数据集
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
    #对图片数据进行归一化
    train_images = train_images / 255
    test_images = test_images / 255
    
    ds_train_images = tf.data.Dataset.from_tensor_slices(train_images)
    ds_train_labels = tf.data.Dataset.from_tensor_slices(train_labels)
    #zip到一起,为了后面的shuffle,否则image与label的会对应错误
    ds_train = tf.data.Dataset.zip((ds_train_images,ds_train_labels))
    
    ds_train  = ds_train.shuffle(10000).repeat().batch(4)
    
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28,28)),
        tf.keras.layers.Dense(128,activation='relu'),
        tf.keras.layers.Dense(10,activation= 'softmax')
    ])
    model.compile(optimizer = 'adam',
                  loss= 'sparse_categorical_crossentropy',
                  metrics = ['accuracy'])
    ds_test = tf.data.Dataset.from_tensor_slices((test_images,test_labels))
    ds_test = ds_test.batch(4)
    steps_per_epoch = train_images.shape[0] / 4 #表明每轮训练多少步,这是因为上面对dataser进行了repeat()所以需要指定每一轮训练多少步
    model.fit(ds_train,epochs=10,steps_per_epoch=steps_per_epoch,validation_data=ds_test) 
    
    


    作者:孙建钊
    出处:http://www.cnblogs.com/sunjianzhao/
    本文版权归作者和博客园共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接,否则保留追究法律责任的权利。

  • 相关阅读:
    CSP-S2019游记
    SOJ 一句话题解整理
    CF547E Mike and Friends
    CF506E Mr. Kitayuta's Gift
    在windows环境下安装和使用Python(CPython)
    GeekGame2020_部分WriteUp
    php通过curl传输JSON对象
    Invalid datetime format: 1292 Incorrect datetime value
    解决 select2 开启 tags 不能输入中文的问题
    electerm 设置同步(Setting sync)
  • 原文地址:https://www.cnblogs.com/sunjianzhao/p/15581513.html
Copyright © 2011-2022 走看看