zoukankan      html  css  js  c++  java
  • 保存和提取：PyTorch 网络的保存与恢复（保存整个网络 / 仅保存参数）

    # -*- coding: utf-8 -*-
    import torch
    import matplotlib.pyplot as plt

    # Uncomment to make the random noise (and hence training) reproducible:
    # torch.manual_seed(1)

    # Fake regression data: 100 evenly spaced points on [-1, 1] as a column
    # vector of shape (100, 1); targets are y = x^2 plus uniform noise in [0, 0.2).
    x = torch.linspace(-1, 1, 100).unsqueeze(1)
    y = x ** 2 + 0.2 * torch.rand_like(x)

    # Since PyTorch 0.4 autograd works on plain tensors, so x and y no longer
    # need to be wrapped in the deprecated Variable class.


    def save(net_path='net.pkl', params_path='net_params.pkl'):
        """Train a small regression net on the module-level ``x``/``y`` and save it.

        Fits a ``Linear(1, 10) -> ReLU -> Linear(10, 1)`` network with SGD
        (lr=0.5) on MSE loss for 1000 steps. Plots the fit into the first
        panel (``subplot(131)``) of figure 1, then saves the model in two ways.

        Args:
            net_path: File that receives the entire pickled module
                (``torch.save(net1, ...)``).
            params_path: File that receives only ``net1.state_dict()``.
        """
        net1 = torch.nn.Sequential(
            torch.nn.Linear(1, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )
        optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
        loss_func = torch.nn.MSELoss()

        for _ in range(1000):
            prediction = net1(x)
            loss = loss_func(prediction, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # The last in-loop prediction was computed before the final
        # optimizer.step(). Recompute it so the plot shows exactly the weights
        # being saved, which are the ones the Net2/Net3 panels reload.
        with torch.no_grad():
            prediction = net1(x)

        # plot result
        plt.figure(1, figsize=(10, 3))
        plt.subplot(131)
        plt.title('Net1')
        plt.scatter(x.numpy(), y.numpy())
        plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)

        # Two ways to save the net:
        torch.save(net1, net_path)                  # entire module (pickled)
        torch.save(net1.state_dict(), params_path)  # parameters only


    def restore_net(net_path='net.pkl'):
        """Load the entire pickled net written by ``save()`` and plot its fit.

        Evaluates the loaded net on the module-level ``x`` and draws the result
        into the second panel (``subplot(132)``) of the current figure.

        Args:
            net_path: File written by ``torch.save(net1, ...)``.
        """
        # A whole nn.Module is an arbitrary pickled object, not just tensors.
        # Since PyTorch 2.6, torch.load defaults to weights_only=True and would
        # refuse to unpickle it, so opt out explicitly. Unpickling can execute
        # arbitrary code, so only load files you created yourself or trust.
        # (The weights_only argument requires PyTorch >= 1.13.)
        net2 = torch.load(net_path, weights_only=False)
        with torch.no_grad():
            prediction = net2(x)

        # plot result
        plt.subplot(132)
        plt.title('Net2')
        plt.scatter(x.numpy(), y.numpy())
        plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)


    def restore_params(params_path='net_params.pkl'):
        """Rebuild the net, load saved parameters into it, plot the fit, and show.

        Only the state_dict was saved, so the same ``Linear(1, 10) -> ReLU ->
        Linear(10, 1)`` structure is rebuilt here before the parameters are
        copied in. The result on the module-level ``x`` is drawn into the third
        panel (``subplot(133)``), and then the figure is shown.

        Args:
            params_path: File written by ``torch.save(net1.state_dict(), ...)``.
        """
        # Must match the architecture used when the parameters were saved;
        # otherwise load_state_dict (strict by default) raises on missing or
        # unexpected keys or on shape mismatches.
        net3 = torch.nn.Sequential(
            torch.nn.Linear(1, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 1),
        )

        # A state_dict holds only tensors, so the restricted, safer unpickler
        # is sufficient. It is the default from PyTorch 2.6; passing it here
        # also avoids the FutureWarning on 1.13-2.5.
        net3.load_state_dict(torch.load(params_path, weights_only=True))
        with torch.no_grad():
            prediction = net3(x)

        # plot result
        plt.subplot(133)
        plt.title('Net3')
        plt.scatter(x.numpy(), y.numpy())
        plt.plot(x.numpy(), prediction.numpy(), 'r-', lw=5)
        plt.show()

    # Demo driver: train and save net1, then reload it both ways so the three
    # panels (Net1, Net2, Net3) can be compared side by side.
    save()            # train net1, plot it, and save it (whole net + params)
    restore_net()     # reload the entire pickled net (the author notes this may be slower)
    restore_params()  # reload only the parameters into a freshly built net
  • 相关阅读:
    angular ngIf指令 以及组件的输入输出
    angular 命令行指令总结
    angular8.x 事件的处理和样式绑定
    nodejs更新版本(windows)
    angular重要指令 ngFor
    emmet 常用总结
    手机真机调试 (ng项目)
    最长回文子串
    最长连续序列
    重复的子字符串
  • 原文地址:https://www.cnblogs.com/shuimuqingyang/p/10288467.html
Copyright © 2011-2022 走看看