zoukankan      html  css  js  c++  java
  • pytorch搭建网络,保存参数,恢复参数

    这是看过莫凡python的学习笔记。

    搭建网络,两种方式

    (1)建立Sequential对象

    import torch
    net = torch.nn.Sequential(
                torch.nn.Linear(2,10),
                torch.nn.ReLU(),
                torch.nn.Linear(10,2))

    输出网络结构

    Sequential(
      (0): Linear(in_features=2, out_features=10, bias=True)
      (1): ReLU()
      (2): Linear(in_features=10, out_features=2, bias=True)
    )

    (2)建立网络类,继承torch.nn.module

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net,self).__init__()
            self.hidden = torch.nn.Linear(2,10)
            self.predict = torch.nn.Linear(10,2)
        def forward(self,x):
            x = F.relu(self.hidden(x))
            x = self.predict(x)
            return x

    输出和上面基本一样,略微不同

    Net(
      (hidden): Linear(in_features=2, out_features=10, bias=True)
      (predict): Linear(in_features=10, out_features=2, bias=True)
    )

    保存模型,两种方式

    (1)保存整个网络,及网络参数

    torch.save(net,'net.pkl')

    (2)只保存网络参数

    torch.save(net.state_dict(),'net_params.pkl')

    恢复模型,两种方式

    (1)加载整个网络,及参数

    net2 = torch.load('net.pkl')

    (2)加载参数,但需实现网络

    net3 = torch.nn.Sequential(
                torch.nn.Linear(2,10),
                torch.nn.ReLU(),
                torch.nn.Linear(10,2))
    net3.load_state_dict(torch.load('net_params.pkl'))
  • 相关阅读:
    IDEA插件和快捷设置
    漫谈虚拟内存
    漫谈进程和线程
    漫谈计算机语言
    初识Python
    数据库物理设计
    漫谈计算机体系
    数据库逻辑设计
    NLP中几种分词库的简单使用(Python)
    ML————朴素贝叶斯原理和SKlearn相关库
  • 原文地址:https://www.cnblogs.com/wzyuan/p/9458008.html
Copyright © 2011-2022 走看看