zoukankan      html  css  js  c++  java
  • 使用pytorch进行线性回归


    x,y
    3.3,1.7
    4.4,2.76
    5.5,2.09
    6.71,3.19
    6.93,1.694
    4.168,1.573
    9.779,3.366
    6.182,2.596
    7.59,2.53
    2.167,1.221
    7.042,2.827
    10.791,3.465
    5.313,1.65
    7.997,2.904
    3.1,1.3

    以上是待拟合的数据，请将其保存为 data.csv（下面的脚本从该文件读取 x、y 两列）。

    import torch
    from torch import nn, optim
    from torch.autograd import Variable
    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    
    # Load the training data. The CSV needs a header row with columns "x"
    # and "y", since they are read below as d.x and d.y.
    d = pd.read_csv("data.csv")

    # nn.Linear(1, 1) consumes inputs of shape (n_samples, 1). reshape(-1, 1)
    # derives n_samples from the file instead of hard-coding 15, so the script
    # keeps working when rows are added to or removed from data.csv.
    x_train = np.array(d.x, dtype=np.float32).reshape(-1, 1)
    print(x_train)
    y_train = np.array(d.y, dtype=np.float32).reshape(-1, 1)
    print(y_train)

    # torch.from_numpy wraps the NumPy buffers without copying them.
    x_train = torch.from_numpy(x_train)
    y_train = torch.from_numpy(y_train)
    
    
    # Linear regression model: a single fully-connected layer.
    class LinearRegression(nn.Module):
        """Single-feature linear regression, ``y = w * x + b``.

        The layer is stored as ``self.linear``. Its parameters therefore appear
        in ``state_dict()`` under the keys ``linear.weight`` and ``linear.bias``.
        """

        def __init__(self):
            super().__init__()
            # One input feature mapped to one output value.
            self.linear = nn.Linear(in_features=1, out_features=1)

        def forward(self, x):
            """Apply the linear layer to ``x``.

            The last dimension of ``x`` must be 1. This script passes an
            (N, 1) float32 tensor.
            """
            return self.linear(x)
    
    
    model = LinearRegression()
    # Define the loss (mean squared error) and the optimizer (plain SGD).
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4)

    # Train. The whole data set is used as a single batch every epoch. The
    # tensors are used directly: wrapping them in torch.autograd.Variable has
    # been a no-op since PyTorch 0.4.
    num_epochs = 1000
    for epoch in range(num_epochs):
        # forward
        out = model(x_train)
        loss = criterion(out, y_train)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 20 == 0:
            # loss is a 0-dim tensor. Indexing it (loss.data[0]) raises
            # IndexError on current PyTorch; .item() returns a Python float.
            print('Epoch[{}/{}], loss: {:.6f}'
                  .format(epoch + 1, num_epochs, loss.item()))
    
    # Switch to evaluation mode and predict without recording an autograd
    # graph. The output can then be converted to NumPy directly, without
    # going through the discouraged .data attribute.
    model.eval()
    with torch.no_grad():
        predict = model(x_train).numpy()

    # Plot the original points and the fitted line.
    plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')
    plt.plot(x_train.numpy(), predict, label='Fitting Line')
    # Show the legend.
    plt.legend()
    plt.show()

    # Save the model's learned parameters (its state_dict).
    torch.save(model.state_dict(), './linear.pth')
    

      

  • 相关阅读:
    php学习【1】
    网页项目源码笔记
    python学习笔记
    php集成开发环境xampp的搭建
    ubuntu18.04.1LTS系统远程工具secureCRT
    关于IT人的一些消遣区
    linux系统的启动过程简要分析
    【shell脚本学习-1】
    Linux命令总结--cat命令
    Linux命令总结--vi/vim命令
  • 原文地址:https://www.cnblogs.com/dudu1992/p/8980249.html
Copyright © 2011-2022 走看看