zoukankan      html  css  js  c++  java
  • fashion_mnist多分类训练,两种模型的保存与加载

    from tensorflow.python.keras.preprocessing.image import load_img,img_to_array
    from tensorflow.python.keras.models import Sequential,Model
    from tensorflow.python.keras.layers import Dense,Flatten,Input
    import tensorflow as tf
    from tensorflow.python.keras.losses import sparse_categorical_crossentropy
    from tensorflow.python import keras
    import os
    import numpy as np
    
    class SingleNN(object):
    
        #建立神经网络模型
        model = keras.Sequential([
            keras.layers.Flatten(input_shape=(28,28)),
            keras.layers.Dense(128,activation=tf.nn.relu),
            keras.layers.Dense(10,activation=tf.nn.softmax)
        ])
    
        def __init__(self):
            (self.x_train,self.y_train),(self.x_test,self.y_test) = keras.datasets.fashion_mnist.load_data()
            #归一化
            self.x_train = self.x_train/255.0
            self.x_test = self.x_test/255.0
    
        def singlenn_compile(self):
            '''
            编译模型优化器、损失、准确率
            :return:
            '''
            SingleNN.model.compile(
                optimizer=keras.optimizers.SGD(lr=0.01),
                loss=keras.losses.sparse_categorical_crossentropy,
                metrics=['accuracy']
            )
    
        def singlenn_fit(self):
            """
            进行fit训练
            :return: 
            """
            SingleNN.model.fit(self.x_train,self.y_train,epochs=5)
    
        def single_evalute(self):
            '''
            模型评估
            :return: 
            '''
            test_loss,test_acc = SingleNN.model.evaluate(self.x_test,self.y_test)
            print(test_loss,test_acc)
    
        def single_predict(self):
            '''
            预测结果
            :return: 
            '''
            # if os.path.exists("./ckpt/checkpoink"):
            #     SingleNN.model.load_weights("./ckpt/SingleNN")
    
            if os.path.exists("./ckpt/SingleNN.h5"):
                SingleNN.model.load_weights("./ckpt/SingleNN.h5")
    
            predictions = SingleNN.model.predict(self.x_test)
    
            return predictions
    
    if __name__ == '__main__':
        snn = SingleNN()
        # snn.singlenn_compile()
        # snn.singlenn_fit()
        # snn.single_evalute()
        # # SingleNN.model.save_weights("./ckpt/SingleNN")
        # SingleNN.model.save_weights("./ckpt/SingleNN.h5")
        predictions = snn.single_predict()
        print(predictions)
        result = np.argmax(predictions,axis=1)
        print(result)
    

      

    多思考也是一种努力,做出正确的分析和选择,因为我们的时间和精力都有限,所以把时间花在更有价值的地方。
  • 相关阅读:
    IOS总结_IOS经常使用的方法集合、调用系统电话、设备区分、APP内永不锁屏
    huffman编码——原理与实现
    python字典构造函数dict(mapping)解析
    tomcat配置sqlserver数据库
    Tomcat全攻略
    第一次QQ群视频教育有感
    UIControl-IOS开发
    java内存分析总结
    Android笔记 之 旋转木马的音乐效果
    Android中API建议的方式实现SQLite数据库的增、删、改、查的操作
  • 原文地址:https://www.cnblogs.com/LiuXinyu12378/p/12250596.html
Copyright © 2011-2022 走看看