zoukankan      html  css  js  c++  java
  • tensorflow-- Dataset创建数据集对象

    tf.data模块包含:

    •  experimental 模块
    •  Dataset 类
    •  FixedLengthRecordDataset 类
    • TFRecordDataset 类
    • TextLineDataset 类
    
    
     1 #  author by FH.
     2 #  OverView:
     3 #  tf.data
     4 #           experimental  ---Modules
     5 #           Dataset      ---class
     6 #           FixedLengthRecordDataset  ---class
     7 #           TFRecordDataset           ---class
     8 #           TextLineDataset           ---class
     9 import tensorflow as tf
    10 import numpy as np
    11 
    12 
    13 # 1. 使用静态方法 tf.data.Dataset.from_tensor_slices
    14 #       将输入的第一个维度切割,形成dataset
    15 # 2. 使用 Dataset的 make_one_shot_iterator() 实例化一个 iterator
    16 #       这个iterator 只能从头到尾读取一次。“one shot iterator”
    17 def test1():
    18     sess = tf.Session()
    19     dataset1 = tf.data.Dataset.from_tensor_slices(np.array([1.0,2.0,3.0,4.0,5.0]))
    20     dataset2 = tf.data.Dataset.from_tensor_slices(np.array([[1,2],[3,4],[0,9]]))
    21     dataset3 = tf.data.Dataset.from_tensor_slices(
    22         {
    23             "a":np.array([1.0,2,3,4,5.0]),
    24             "b":np.random.uniform(size=(5,2))
    25         }
    26     )
    27     # 使用 Dataset的 make_one_shot_iterator() 实例化一个 iterator
    28     #     这个iterator 只能从头到尾读取一次。“one shot iterator”
    29     oneShotIterator1 = dataset1.make_one_shot_iterator()
    30     oneShotIterator2 = dataset2.make_one_shot_iterator()
    31     oneShotIterator3 = dataset3.make_one_shot_iterator()
    32     element1 = oneShotIterator1.get_next()
    33     element2 = oneShotIterator2.get_next()
    34     element3 = oneShotIterator3.get_next()
    35     for i in range(5):
    36         print(sess.run(element1))
    37     for i in range(3):
    38         print(sess.run(element2))
    39     for i in range(5):
    40         print(sess.run(element3))
    41     sess.close()
    42 
    43 # 1.Dataset 中的数据元素转换。
    44 #           map() :参数为一个函数,将dataset中的每个元素带入获取新的值
    45 #           batch(): 参数为一个整数,将多个元素组合成一个batch
    46 def test2():
    47     sess = tf.Session()
    48     dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0,6]))
    49     # map() 重新映射新的元素值
    50     dataset1 = dataset.map(lambda x: x * 3)
    51     # batch()  2个组成一个batch, 组成batch 之后size 为3
    52     dataset2 = dataset.batch(2)
    53     # shuffle() 打乱dataset
    54     dataset3 = dataset.shuffle(buffer_size=3)
    55     # repeat()  将整个序列重复多次,重复4次 size 为24
    56     dataset4 = dataset.repeat(4)
    57 
    58     oneShotIterator1 = dataset1.make_one_shot_iterator()
    59     oneShotIterator2 = dataset2.make_one_shot_iterator()
    60     oneShotIterator3 = dataset3.make_one_shot_iterator()
    61     oneShotIterator4 = dataset4.make_one_shot_iterator()
    62     element1 = oneShotIterator1.get_next()
    63     element2 = oneShotIterator2.get_next()
    64     element3 = oneShotIterator3.get_next()
    65     element4 = oneShotIterator4.get_next()
    66     for i in range(6):  # map()
    67         print(sess.run(element1))
    68     for i in range(3):  # batch()
    69         print(sess.run(element2))
    70     for i in range(6):  # shuffle()
    71         print(sess.run(element3))
    72     for i in range(24): # repeat()
    73         print(sess.run(element4))
    74     sess.close()
    75 
    76 # example1: 读取图片和相应的标签并打乱,组成
    77 #          batch_size=2 的数据集,重复10 epoch
    78 def _parse_function(imgfilename,label):
    79     image_value = tf.read_file(imgfilename)
    80     img = tf.image.decode_image(image_value)
    81     img = tf.image.resize_images(img,[256,256])
    82     return img,label
    83 def example1():
    84     # 图片列表
    85     filesnames = tf.constant(['name1.jpg','name3.jpg','name5.jpg','name6.jpg','name7.jpg','name8.jpg'])
    86     # 对应标签
    87     labels = tf.constant([0,1,0,1,1,0])
    88     # dataset  (名称,标签)
    89     dataset = tf.data.Dataset.from_tensor_slices((filesnames,labels))
    90     # map 映射成图片和标签
    91     dataset = dataset.map(_parse_function)
    92     # shuffle ,batch , repeat
    93     dataset = dataset.shuffle(buffersize=3).batch(2).repeat(10)
    94     return dataset
    95 
    96 if __name__ == '__main__':
    97     test2()
    View Code
  • 相关阅读:
    mysql5.6 sql_mode设置为宽松模式
    utf-8 编码问题
    阿里云服务器挂载云盘
    maven打包含有多个main程序的jar包及运行方式
    AndroidStudio OpenCv的配置,不用安装opencv manager
    图片标注工具LabelImg使用教程
    关于tensorboard启动问题
    IntelliJ IDEA 最新激活码(截止到2018年10月14日)
    JetBrains C++ IDE CLion配置与评测
    Win10下Clion配置opencv3
  • 原文地址:https://www.cnblogs.com/feihu-h/p/11677443.html
Copyright © 2011-2022 走看看