  • Keras in TensorFlow with Dataset, Part 1

    Using TensorFlow's Keras implementation together with a Dataset: several forms all work!

    With TensorFlow's Keras in Sequential mode:

    See the code:

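    # Targets TensorFlow 1.x (graph mode): Dataset.make_one_shot_iterator() and
    # feeding iterator tensors straight into fit() are not available in TF 2.x.
    from tensorflow import keras as ks
    import tensorflow as tf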
    
    # Generate dummy data
    import numpy as np
    x_train = np.random.random((1000, 20))
    y_train = ks.utils.to_categorical(np.random.randint(10, size=(1000, 1)), num_classes=10)
    x_test = np.random.random((100, 20))
    y_test = ks.utils.to_categorical(np.random.randint(10, size=(100, 1)), num_classes=10)
    
    
    batch_size = 100
    steps_per_epoch = int(np.ceil(x_train.shape[0]/batch_size))
    
    train_ds = tf.data.Dataset.from_tensor_slices((x_train,y_train))
    train_ds = train_ds.batch(batch_size)   # batch() adds a leading batch dimension to each element
    train_ds = train_ds.repeat()
    
    train_it = train_ds.make_one_shot_iterator()
    x_train_it, y_train_it = train_it.get_next()
    
    
    test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    test_ds = test_ds.batch(batch_size)
    test_ds = test_ds.repeat()
    
    model = ks.models.Sequential()
    # Dense(64) is a fully-connected layer with 64 hidden units.
    # in the first layer, you must specify the expected input data shape:
    # here, 20-dimensional vectors.
    model.add(ks.layers.Dense(64, activation='relu', input_dim=20))
    model.add(ks.layers.Dropout(0.5))
    model.add(ks.layers.Dense(64, activation='relu'))
    model.add(ks.layers.Dropout(0.5))
    model.add(ks.layers.Dense(10, activation='softmax'))
    
    sgd = ks.optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss='categorical_crossentropy', optimizer=sgd,   metrics=['accuracy'])
    
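    # Either of the two styles below works for passing the data to the model:
    # 1) the (x, y) tensors produced by the one-shot iterator
    model.fit(x_train_it, y_train_it, epochs=20, steps_per_epoch=steps_per_epoch)
    print("(+(" * 20, '\n' * 4)  # visual separator between the two runs
    # 2) the tf.data.Dataset itself (this continues training the same model)
    model.fit(train_ds, epochs=20, steps_per_epoch=steps_per_epoch)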
    
    score = model.evaluate(test_ds, steps=128)
    print(score)
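Why `steps_per_epoch` is passed: after `.repeat()` the dataset never ends, so Keras cannot tell on its own where one epoch stops. The sketch below mimics `Dataset.from_tensor_slices((x, y)).batch(batch_size).repeat()` in plain NumPy to show the batch shapes and why `ceil(1000 / 100) = 10` steps visit the training set exactly once. The helper `batched_repeat` is hypothetical, written only for illustration, and is not a TensorFlow API.

```python
import math

import numpy as np


def batched_repeat(x, y, batch_size):
    """Hypothetical plain-NumPy stand-in for
    tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size).repeat().

    Slices x and y along axis 0, groups consecutive rows into batches,
    and cycles over the data forever (repeat() with no count).
    """
    n = x.shape[0]
    while True:  # infinite stream, like repeat()
        for start in range(0, n, batch_size):
            # slicing keeps the leading axis, so each batch has shape (<=batch_size, ...)
            yield x[start:start + batch_size], y[start:start + batch_size]


# Same dummy-data shapes as above; np.eye(10)[labels] builds one-hot rows like to_categorical
x_train = np.random.random((1000, 20))
y_train = np.eye(10)[np.random.randint(10, size=1000)]

batch_size = 100
steps_per_epoch = math.ceil(x_train.shape[0] / batch_size)  # 1000 / 100 -> 10

stream = batched_repeat(x_train, y_train, batch_size)
seen = 0
for _ in range(steps_per_epoch):
    xb, yb = next(stream)
    seen += xb.shape[0]

# After steps_per_epoch batches every sample has been visited once,
# and the next batch starts over at row 0 because the stream never ends.
xb_next, _ = next(stream)
wrapped = np.array_equal(xb_next, x_train[:batch_size])

print(steps_per_epoch, seen, xb.shape, yb.shape, wrapped)
# 10 1000 (100, 20) (100, 10) True
```

By the same arithmetic, the 100-sample test set fits in a single batch of 100, so in `model.evaluate(test_ds, steps=...)` any value above 1 simply re-reads the same 100 samples.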
  • Original post: https://www.cnblogs.com/wdmx/p/10256713.html