zoukankan      html  css  js  c++  java
  • 一文讲解TensorFlow数据接口 tf.data.Dataset

    导入数据

    X = pd.read_csv('./datasets/housing/housing.csv')
    X = X.sample(n=10)
    X.drop(columns = X.columns.difference(['longitude']), inplace=True)
    

    为了避免报错,先进行格式转换:

    X = np.asarray(X).astype(np.float32)
    
    dataset = tf.data.Dataset.from_tensor_slices(X)
    for _ in dataset:
        print(_)
    
    tf.Tensor([-118.75], shape=(1,), dtype=float32)
    tf.Tensor([-119.25], shape=(1,), dtype=float32)
    tf.Tensor([-118.18], shape=(1,), dtype=float32)
    tf.Tensor([-118.13], shape=(1,), dtype=float32)
    tf.Tensor([-118.2], shape=(1,), dtype=float32)
    tf.Tensor([-117.25], shape=(1,), dtype=float32)
    tf.Tensor([-117.93], shape=(1,), dtype=float32)
    tf.Tensor([-122.96], shape=(1,), dtype=float32)
    tf.Tensor([-121.77], shape=(1,), dtype=float32)
    tf.Tensor([-121.24], shape=(1,), dtype=float32)
    
    dataset = dataset.repeat(3).batch(10)
    for _ in dataset:
        print(_)
    

    图解:

    repeat(3)将数据集重复3次,batch(10)每次输出一个包括10个元素的batch。

    tf.Tensor(
    [[-118.75]
     [-119.25]
     [-118.18]
     [-118.13]
     [-118.2 ]
     [-117.25]
     [-117.93]
     [-122.96]
     [-121.77]
     [-121.24]], shape=(10, 1), dtype=float32)
    tf.Tensor(
    [[-118.75]
     [-119.25]
     [-118.18]
     [-118.13]
     [-118.2 ]
     [-117.25]
     [-117.93]
     [-122.96]
     [-121.77]
     [-121.24]], shape=(10, 1), dtype=float32)
    tf.Tensor(
    [[-118.75]
     [-119.25]
     [-118.18]
     [-118.13]
     [-118.2 ]
     [-117.25]
     [-117.93]
     [-122.96]
     [-121.77]
     [-121.24]], shape=(10, 1), dtype=float32)
    

    如果不能刚好等分,例如

    dataset = dataset.repeat(3).batch(9)
    for _ in dataset:
        print(_)
    

    最后一个batch将包含剩下的元素

    tf.Tensor(
    [[-122.08]
     [-121.37]
     [-118.32]
     [-122.38]
     [-122.09]
     [-122.1 ]
     [-122.27]
     [-121.49]
     [-120.68]], shape=(9, 1), dtype=float64)
    tf.Tensor(
    [[-118.2 ]
     [-122.08]
     [-121.37]
     [-118.32]
     [-122.38]
     [-122.09]
     [-122.1 ]
     [-122.27]
     [-121.49]], shape=(9, 1), dtype=float64)
    tf.Tensor(
    [[-120.68]
     [-118.2 ]
     [-122.08]
     [-121.37]
     [-118.32]
     [-122.38]
     [-122.09]
     [-122.1 ]
     [-122.27]], shape=(9, 1), dtype=float64)
    tf.Tensor(
    [[-121.49]
     [-120.68]
     [-118.2 ]], shape=(3, 1), dtype=float64)
    

    map函数

    dataset = dataset.map(lambda x: abs(x))
    for _ in dataset:
        print(_)
    
    tf.Tensor(
    [[118.75]
     [119.25]
     [118.18]
     [118.13]
     [118.2 ]
     [117.25]
     [117.93]
     [122.96]
     [121.77]
     [121.24]], shape=(10, 1), dtype=float32)
    tf.Tensor(
    [[118.75]
     [119.25]
     [118.18]
     [118.13]
     [118.2 ]
     [117.25]
     [117.93]
     [122.96]
     [121.77]
     [121.24]], shape=(10, 1), dtype=float32)
    tf.Tensor(
    [[118.75]
     [119.25]
     [118.18]
     [118.13]
     [118.2 ]
     [117.25]
     [117.93]
     [122.96]
     [121.77]
     [121.24]], shape=(10, 1), dtype=float32)
    

    filter函数

    使用filter函数前需要先unbatch

    dataset = dataset.unbatch()
    dataset = dataset.filter(lambda x: x < 120)
    
  • 相关阅读:
    PHPCMS网站关站了打不开-站长真的凉了吗?
    PHPCMS倒闭关站后,国内CMS系统该何去何从
    企业网站建设如何选择cms建站系统
    网站建设之常用CMS系统的SEO优化特点总结
    PageAdmin CMS仿站教程,如此简单就可以自己建网站
    c#之lamda表达式的前世今生
    c#之Linq的原理讲解及封装自己的Linq
    三大CMS建站系统助你免费建网站
    网站建设的完整流程来了,新手必看
    从零自学Java-7.使用数组存储信息
  • 原文地址:https://www.cnblogs.com/yaos/p/12757972.html
Copyright © 2011-2022 走看看