  • tf2 callback

    Excerpted from a TF2 video tutorial on Bilibili.

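    The three callbacks we use most often are:

      TensorBoard: writes training logs that can be viewed in TensorBoard

      ModelCheckpoint: saves the model during training

      EarlyStopping: stops training once the monitored metric stops improving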
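    All three are objects that `fit()` calls back into at fixed points of training, for example at the end of every epoch. At each of these points it passes the current metrics in a `logs` dict. The pure-Python toy loop below is a rough, hypothetical sketch of that protocol, not the Keras implementation. `ToyCallback`, `EpochRecorder`, `StopBelow` and `toy_fit` are invented names, and the loss values are fake. In real Keras code, a custom callback subclasses `tf.keras.callbacks.Callback` and overrides methods such as `on_epoch_end(self, epoch, logs=None)`. A callback stops training by setting `self.model.stop_training = True`.

```python
# Toy sketch of the callback protocol used by Keras' fit().
# ToyCallback, EpochRecorder, StopBelow and toy_fit are hypothetical names
# invented for this illustration; they are not part of any library.


class ToyCallback:
    """Base class: the hook does nothing unless a subclass overrides it."""

    stop_training = False

    def on_epoch_end(self, epoch, logs=None):
        pass


class EpochRecorder(ToyCallback):
    """Remembers the logs dict seen at the end of each epoch."""

    def __init__(self):
        self.seen = []

    def on_epoch_end(self, epoch, logs=None):
        self.seen.append((epoch, dict(logs or {})))


class StopBelow(ToyCallback):
    """Requests a stop once val_loss falls below a threshold."""

    def __init__(self, threshold):
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        if logs["val_loss"] < self.threshold:
            self.stop_training = True  # real Keras sets model.stop_training instead


def toy_fit(val_losses, callbacks):
    """Pretend to train: epoch i 'produces' val_losses[i]. Returns epochs run."""
    epochs_run = 0
    for epoch, val_loss in enumerate(val_losses):
        logs = {"val_loss": val_loss}
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)  # every callback sees the same logs
        epochs_run += 1
        if any(cb.stop_training for cb in callbacks):
            break
    return epochs_run


recorder = EpochRecorder()
ran = toy_fit([0.9, 0.6, 0.4, 0.3], [recorder, StopBelow(0.5)])
print(ran)                            # 3
print([e for e, _ in recorder.seen])  # [0, 1, 2]
```

    Passing the same `logs` dict to every callback is what lets TensorBoard, ModelCheckpoint and EarlyStopping all react to one epoch's metrics independently.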

    They can be used like this:

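    logdir = "./callback"
    os.makedirs(logdir, exist_ok=True)  # create the log directory if needed
    output_model_file = os.path.join(logdir, "fashion_mnist_model.h5")
    callbacks = [
        k.callbacks.TensorBoard(logdir),
        # Keep only the model with the lowest val_loss seen so far
        k.callbacks.ModelCheckpoint(output_model_file, save_best_only=True),
        # Stop once val_loss has not improved by min_delta for 5 epochs
        k.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
    ]
    # Train on the standardized data, not the raw pixel values
    history = model.fit(x_train_scaled, y_train, epochs=10,
                        validation_data=(x_valid_scaled, y_valid),
                        callbacks=callbacks)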
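    Both `EarlyStopping` and `ModelCheckpoint` watch `val_loss` by default, which is why `validation_data` is passed to `fit()`. The function below is a simplified sketch, not the Keras source, whose exact bookkeeping varies between versions. It replays a made-up list of per-epoch validation losses through the two rules configured in the snippet:

    - `save_best_only=True` overwrites the checkpoint whenever `val_loss` reaches a new minimum.
    - `EarlyStopping(patience=5, min_delta=1e-3)` only resets its counter when the loss drops by more than `min_delta`.

    The name `replay` and the loss values are hypothetical.

```python
# Simplified replay of the rules behind
#   EarlyStopping(patience=5, min_delta=1e-3)   (monitor="val_loss" by default)
#   ModelCheckpoint(path, save_best_only=True)  (monitor="val_loss" by default)
# Illustrative approximation only; `replay` is a hypothetical helper.


def replay(val_losses, patience=5, min_delta=1e-3):
    """Return (epochs_run, epochs at which a checkpoint would be written)."""
    es_best = float("inf")    # best value as tracked by EarlyStopping
    ckpt_best = float("inf")  # best value as tracked by ModelCheckpoint
    wait = 0                  # epochs since the last significant improvement
    saved = []
    for epoch, loss in enumerate(val_losses):
        # ModelCheckpoint: save whenever val_loss is strictly lower than before.
        if loss < ckpt_best:
            ckpt_best = loss
            saved.append(epoch)
        # EarlyStopping: only a drop larger than min_delta counts as progress.
        if loss < es_best - min_delta:
            es_best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch + 1, saved
    return len(val_losses), saved


losses = [0.50, 0.40, 0.35, 0.3495, 0.3496, 0.36, 0.37, 0.38, 0.39, 0.40]
ran, saved = replay(losses)
print(ran, saved)  # 8 [0, 1, 2, 3]
```

    In this made-up run, epoch 3 is still checkpointed because its loss is a new minimum. Its 0.0005 drop is below `min_delta`, however, so it does not reset the early-stopping counter, and training stops after 8 of the 10 epochs.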

    Complete code:

    #!/usr/bin/env python
    # coding: utf-8
    
    # In[2]:
    
    
    import tensorflow as tf
    import tensorflow.keras as k
    import numpy as np
    import matplotlib.pyplot as plt
    import os
    
    
    # In[3]:
    
    
    fashion_mnist = k.datasets.fashion_mnist
    (x_train,y_train),(x_test,y_test)=fashion_mnist.load_data()
    # Hold out the first 5,000 training images for validation
    x_valid, x_train = x_train[:5000], x_train[5000:]
    y_valid, y_train = y_train[:5000], y_train[5000:]
    
    
    # In[4]:
    
    
    from sklearn.preprocessing import StandardScaler
    # reshape(-1, 1) treats every pixel as one sample of a single feature, so the
    # scaler learns one global mean/std on the training set and reuses it below
    scaler = StandardScaler()
    x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
    x_valid_scaled = scaler.transform(x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
    x_test_scaled = scaler.transform(x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
    
    
    # In[5]:
    
    
    # Build the model
    model = k.Sequential()
    model.add(k.layers.Flatten(input_shape=[28, 28]))
    model.add(k.layers.Dense(300, activation="relu"))
    model.add(k.layers.Dense(100, activation="relu"))
    model.add(k.layers.Dense(10, activation="softmax"))
    
    model.compile(loss="sparse_categorical_crossentropy",
                  optimizer="adam",
                  metrics=["accuracy"])
    
    
    # In[7]:
    
    
    
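    logdir = "./callback"
    os.makedirs(logdir, exist_ok=True)  # create the log directory if needed
    output_model_file = os.path.join(logdir, "fashion_mnist_model.h5")
    callbacks = [
        k.callbacks.TensorBoard(logdir),
        # Keep only the model with the lowest val_loss seen so far
        k.callbacks.ModelCheckpoint(output_model_file, save_best_only=True),
        # Stop once val_loss has not improved by min_delta for 5 epochs
        k.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
    ]
    # Train on the standardized data, not the raw pixel values
    history = model.fit(x_train_scaled, y_train, epochs=10,
                        validation_data=(x_valid_scaled, y_valid),
                        callbacks=callbacks)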
    
    
    # In[ ]:
    
    
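    import pandas as pd
    
    def plot_curve(history):
        # history.history maps each metric (loss, accuracy, val_loss,
        # val_accuracy) to its list of per-epoch values
        pd.DataFrame(history.history).plot(figsize=(8, 5))
        plt.grid(True)
        plt.gca().set_ylim(0, 1)
        plt.show()
    
    plot_curve(history)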
    
    
    # In[ ]:
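    In the complete code, the `reshape(-1, 1)` calls make every pixel a separate sample of a single feature. `StandardScaler` therefore learns one global mean and one standard deviation from the training images, and `transform` reuses those statistics for the validation and test sets. Below is a minimal NumPy sketch of the same computation. It assumes small random arrays as stand-ins for the real Fashion-MNIST images. scikit-learn computes in float64, so its results may differ slightly in the last digits.

```python
import numpy as np

# Stand-in data (assumption): random uint8 "images" with Fashion-MNIST's
# 28x28 shape; the real arrays come from k.datasets.fashion_mnist.load_data().
rng = np.random.default_rng(0)
x_train = rng.integers(0, 256, size=(100, 28, 28)).astype(np.uint8)
x_valid = rng.integers(0, 256, size=(20, 28, 28)).astype(np.uint8)

# Equivalent of StandardScaler after reshape(-1, 1): a single feature, so one
# mean and one population standard deviation over all training pixels.
train_f = x_train.astype(np.float32)
mean = train_f.mean()
std = train_f.std()

x_train_scaled = (train_f - mean) / std
# Validation (and test) data reuse the training statistics, like transform().
x_valid_scaled = (x_valid.astype(np.float32) - mean) / std

print(x_train_scaled.shape, float(x_train_scaled.mean()), float(x_train_scaled.std()))
```

    Fitting the statistics on the training split only keeps information from the validation and test images out of preprocessing.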
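    As a quick sanity check on the model size in the complete code, a `Dense` layer with `n_in` inputs and `n_out` units has `n_in * n_out` kernel weights plus `n_out` biases. `Flatten` has no parameters. The arithmetic for the 784 → 300 → 100 → 10 network:

```python
# Dense layer parameters = inputs * units (kernel) + units (bias).
# Layer sizes copied from the model in the complete code.
layer_sizes = [28 * 28, 300, 100, 10]  # Flatten output, then three Dense layers

params = [n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:])]
print(params)       # [235500, 30100, 1010]
print(sum(params))  # 266610
```

    `model.summary()` should report the same total of 266,610 trainable parameters.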
  • Original article: https://www.cnblogs.com/superxuezhazha/p/12363521.html