zoukankan      html  css  js  c++  java
  • TensorFlow-keras fit的callbacks参数，定时保存模型

    from tensorflow.python.keras.preprocessing.image import load_img,img_to_array
    from tensorflow.python.keras.models import Sequential,Model
    from tensorflow.python.keras.layers import Dense,Flatten,Input
    import tensorflow as tf
    from tensorflow.python.keras.losses import sparse_categorical_crossentropy
    from tensorflow.python import keras
    import os
    import numpy as np
    
    class SingleNN(object):
        """Single-hidden-layer dense classifier for Fashion-MNIST, trained with Keras.

        ``model`` is a *class* attribute: it is built once, when this class body
        executes, and is shared by every ``SingleNN`` instance. Compiling,
        training, or loading weights through any instance therefore changes the
        same network, which can also be reached directly as ``SingleNN.model``.
        """

        # Network: flatten each 28x28 image into a vector, one 128-unit ReLU
        # hidden layer, then a 10-unit softmax output (one probability per class).
        model = keras.Sequential([
            keras.layers.Flatten(input_shape=(28, 28)),
            keras.layers.Dense(128, activation=tf.nn.relu),
            keras.layers.Dense(10, activation=tf.nn.softmax)
        ])

        def __init__(self):
            """Load Fashion-MNIST train/test splits and scale pixel values to [0, 1]."""
            (self.x_train, self.y_train), (self.x_test, self.y_test) = (
                keras.datasets.fashion_mnist.load_data()
            )
            # Normalize: divide by 255 so 8-bit pixel intensities land in [0, 1],
            # a common preprocessing step for gradient-based training.
            self.x_train = self.x_train / 255.0
            self.x_test = self.x_test / 255.0

        def singlenn_compile(self, learning_rate=0.01):
            """Compile the shared model for training.

            Uses plain SGD and sparse categorical cross-entropy, which expects
            integer class ids as labels rather than one-hot vectors. Accuracy is
            reported as a metric.

            :param learning_rate: SGD step size; defaults to the original 0.01.
            :return: None
            """
            SingleNN.model.compile(
                # NOTE(review): ``lr`` matches the private ``tensorflow.python.keras``
                # optimizers this file imports; the public tf.keras API spells it
                # ``learning_rate`` -- confirm against the installed TensorFlow
                # version before renaming.
                optimizer=keras.optimizers.SGD(lr=learning_rate),
                loss=keras.losses.sparse_categorical_crossentropy,
                metrics=['accuracy']
            )

        def singlenn_fit(self, epochs=5, log_dir="./graph"):
            """Train the shared model on the training split, logging to TensorBoard.

            :param epochs: number of passes over the training data (default 5).
            :param log_dir: TensorBoard log directory (default "./graph").
            :return: the ``History`` object returned by ``model.fit``.
            """
            # Optional: save weights after every epoch via a ``fit`` callback.
            # NOTE(review): the filename template must use the metric key that
            # this TF version logs (``acc`` in older releases, ``accuracy`` in
            # newer ones). ``monitor`` selects whether loss or accuracy decides
            # "best" when ``save_best_only`` is on, and ``val_*`` keys only exist
            # when ``fit`` is also given validation data.
            # modelcheck = keras.callbacks.ModelCheckpoint("./ckpt/singlenn_{epoch:02d}-{acc:.2f}.h5",
            #                                         # monitor="val_acc",
            #                                         # save_best_only=True,
            #                                         save_weights_only=True,
            #                                         mode='auto',
            #                                         period=1
            #                                         )
            # To enable it, also append ``modelcheck`` to ``callbacks`` below.
            board = keras.callbacks.TensorBoard(log_dir=log_dir, write_graph=True)
            return SingleNN.model.fit(self.x_train, self.y_train, epochs=epochs,
                                      callbacks=[board])

        def single_evalute(self):
            """Evaluate the shared model on the test split.

            The method name's spelling is kept as-is for existing callers.
            Prints the test loss and accuracy and also returns them.

            :return: ``(test_loss, test_acc)``
            """
            test_loss, test_acc = SingleNN.model.evaluate(self.x_test, self.y_test)
            print(test_loss, test_acc)
            return test_loss, test_acc

        def single_predict(self, weights_path="./ckpt/SingleNN.h5"):
            """Predict class probabilities for every test image.

            If a weights file exists at ``weights_path``, it is loaded into the
            shared model first. Otherwise the model's current weights are used.

            :param weights_path: HDF5 weights file (default "./ckpt/SingleNN.h5").
            :return: the array from ``model.predict``, with one row of 10 softmax
                probabilities per test image.
            """
            # Alternative for weights saved without the .h5 suffix
            # (``save_weights("./ckpt/SingleNN")``):
            # if os.path.exists("./ckpt/checkpoint"):
            #     SingleNN.model.load_weights("./ckpt/SingleNN")
            if os.path.exists(weights_path):
                SingleNN.model.load_weights(weights_path)

            return SingleNN.model.predict(self.x_test)
    
    if __name__ == '__main__':
        net = SingleNN()
        # Pipeline: compile the shared model, train it, then score it on the test split.
        for step in (net.singlenn_compile, net.singlenn_fit, net.single_evalute):
            step()
        # Optional: persist the trained weights and reuse them for prediction.
        # # SingleNN.model.save_weights("./ckpt/SingleNN")   # alternative: no .h5 suffix
        # SingleNN.model.save_weights("./ckpt/SingleNN.h5")
        # predictions = net.single_predict()
        # print(predictions)
        # result = np.argmax(predictions, axis=1)  # highest-probability class per test image
        # print(result)
    

      

    多思考也是一种努力，做出正确的分析和选择，因为我们的时间和精力都有限，所以把时间花在更有价值的地方。
  • 相关阅读:
    POJ
    HDU
    Python之列表
    列表、元组、字典总结
    Python之列表、元组、字典总结
    [P1082][NOIP2012] 同余方程 (扩展欧几里得/乘法逆元)
    [P3957][NOIP2017]跳房子 (DP+二分/队列?)
    [Codeforces896C] Willem, Chtholly and Seniorious (ODT-珂朵莉树)
    [P1005][NOIP2007] 矩阵取数游戏 (DP+高精)
    [POJ1006]生理周期 (中国剩余定理)
  • 原文地址:https://www.cnblogs.com/LiuXinyu12378/p/12250739.html
Copyright © 2011-2022 走看看