zoukankan      html  css  js  c++  java
  • 【小白学PyTorch】19 TF2模型的存储与载入

    【新闻】:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测、医学图像、时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会。微信:cyx645016617.

    参考目录:

    本文主要讲述TF2.0的模型文件的存储和载入的多种方法。主要分成两类型:模型结构和参数一起载入,模型的结构载入。

    1 模型的构建

    import tensorflow.keras as keras
    
    class CBR(keras.layers.Layer):
        def __init__(self,output_dim):
            super(CBR,self).__init__()
            self.conv = keras.layers.Conv2D(filters=output_dim, kernel_size=4, padding='same', strides=1)
            self.bn = keras.layers.BatchNormalization(axis=3)
            self.ReLU = keras.layers.ReLU()
    
        def call(self, inputs):
            inputs = self.conv(inputs)
            inputs = self.ReLU(self.bn(inputs))
            return inputs
    
    class MyNet(keras.Model):
        def __init__ (self):
            super(MyNet,self).__init__()
            self.cbr1 = CBR(16)
            self.maxpool1 = keras.layers.MaxPool2D(pool_size=(2,2))
            self.cbr2 = CBR(32)
            self.maxpool2 = keras.layers.MaxPool2D(pool_size=(2,2))
    
        def call(self, inputs):
            inputs = self.maxpool1(self.cbr1(inputs))
            inputs = self.maxpool2(self.cbr2(inputs))
            return inputs
    
    model = MyNet()
    

    部分朋友可以发现,上面的代码就是上一次课程所构建的一个自定义的网络。

    我们现在需要展示这个模型的框架:

    model.build((16,224,224,3))
    print(model.summary())
    

    运行结果为:

    这里需要对网络执行一个构建.build()函数,之后才能生成model.summary()这样的模型的描述。 这是因为模型的参数量是需要知道输入数据的通道数的,假如我们输入的是单通道的图片,那么就是:

    model.build((16,224,224,1))
    print(model.summary())
    

    输出结果为:

    2 结构参数的存储与载入

    model.save('save_model.h5')
    new_model = keras.models.load_model('save_model.h5')
    

    这里并不能保存成功,出现这样的错误:

    大概的意思就是:因为你的模型不是官方的模型,是自定义的,所以并不能同时保存结构和参数。只有官方的模型可以时候上面的保存的方法,同时保存参数和权重;自定义的模型建议只保存参数

    3 参数的存储与载入

    model.save_weights('model_weight')
    new_model = MyNet()
    new_model.load_weights('model_weight')
    

    这样子就可以保存自定义的模型了。在对应的目录下会出现这几个文件:

    我们来看一下原来的模型和载入的模型对于同一个样本给出的结果是否相同:

    # 看一下原来的模型和载入的模型预测相同的样本的输出
    test = tf.ones((1,8,8,3))
    prediction = model.predict(test)
    new_prediction = new_model.predict(test)
    print(prediction,new_prediction)
    >>> [[[[0.02559286]]]] [[[[0.02559286]]]]
    

    结果相同,载入的没有问题~

    4 结构的存储与载入

    结构的存储有两种方法:

    • model.get_config()
    • model.to_json()

    需要注意的是,上面的两个方法和save的问题一样,是不能用在自定义的模型中的,如果你在其中使用了自定义的Layer类,那么只能!只能用save_weights的方式进行保存

    下面依然给出这两种方法的代码,对于简单的、已经封装好的一些网络层构成的网络,是可以使用这些的。我个人还是常用save_weights啦

    # 第一种方法
    config = model.get_config()
    reinitialized_model = keras.Model.from_config(config)
    # 第二种方法
    json_config = model.to_json()
    # 把json写的文件中
    with open('model_config.json', 'w') as json_file:
        json_file.write(json_config)
    # 读取本地json文件
    with open('model_config.json') as json_file:
        json_config = json_file.read()
    reinitialized_model = keras.models.model_from_json(json_config)
    

    今天的内容就是这么多,虽然提供了四种方法,但是对于自定义程度较高的模型,还是要使用save_weights哦~

  • 相关阅读:
    eclipse下c/cpp " undefined reference to " or "launch failed binary not found"问题
    blockdev 设置文件预读大小
    宝宝语录
    CentOS修改主机名(hostname)
    subprocess报No such file or directory
    用ldap方式访问AD域的的错误解释
    英特尔的VTd技术是什么?
    This virtual machine requires the VMware keyboard support driver which is not installed
    Linux内核的文件预读详细详解
    UNP总结 Chapter 26~29 线程、IP选项、原始套接字、数据链路访问
  • 原文地址:https://www.cnblogs.com/PythonLearner/p/13776682.html
Copyright © 2011-2022 走看看