  • Linear regression with PyTorch


    x,y
    3.3,1.7
    4.4,2.76
    5.5,2.09
    6.71,3.19
    6.93,1.694
    4.168,1.573
    9.779,3.366
    6.182,2.596
    7.59,2.53
    2.167,1.221
    7.042,2.827
    10.791,3.465
    5.313,1.65
    7.997,2.904
    3.1,1.3

    These are the points we want to fit. The script reads them from a file named data.csv.
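    The script expects the points in data.csv with an "x,y" header row. As a minimal sketch, assuming the file should be created in the current working directory, it can be generated from the listed values with Python's standard csv module:

```python
import csv

# The 15 (x, y) points listed above
points = [
    (3.3, 1.7), (4.4, 2.76), (5.5, 2.09), (6.71, 3.19), (6.93, 1.694),
    (4.168, 1.573), (9.779, 3.366), (6.182, 2.596), (7.59, 2.53), (2.167, 1.221),
    (7.042, 2.827), (10.791, 3.465), (5.313, 1.65), (7.997, 2.904), (3.1, 1.3),
]

# Write the header row "x,y", then one point per row
with open("data.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["x", "y"])
    writer.writerows(points)
```

    The column names x and y matter: the script accesses them as d.x and d.y after pd.read_csv.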

    import torch
    from torch import nn, optim
    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd

    # Load the data; reshape(-1, 1) gives an (N, 1) column for any number of rows
    d = pd.read_csv("data.csv")
    x_train = np.array(d.x, dtype=np.float32).reshape(-1, 1)
    y_train = np.array(d.y, dtype=np.float32).reshape(-1, 1)
    print(x_train)
    print(y_train)

    # Convert the NumPy arrays to tensors (they share memory)
    x_train = torch.from_numpy(x_train)
    y_train = torch.from_numpy(y_train)


    # Linear regression model: out = w * x + b
    class LinearRegression(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(1, 1)  # input and output are both 1-dimensional

        def forward(self, x):
            return self.linear(x)


    model = LinearRegression()
    # Define the loss and the optimizer
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4)

    # Train
    num_epochs = 1000
    for epoch in range(num_epochs):
        # Forward pass (since PyTorch 0.4, tensors track gradients without Variable)
        out = model(x_train)
        loss = criterion(out, y_train)
        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 20 == 0:
            # loss is a 0-dim tensor: loss.data[0] raises an error on PyTorch >= 0.4, use item()
            print('Epoch[{}/{}], loss: {:.6f}'.format(epoch + 1, num_epochs, loss.item()))

    model.eval()
    with torch.no_grad():  # no gradients are needed for prediction
        predict = model(x_train).numpy()
    plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')
    plt.plot(x_train.numpy(), predict, label='Fitting Line')
    # Show the legend
    plt.legend()
    plt.show()

    # Save the model parameters
    torch.save(model.state_dict(), './linear.pth')
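
    To make explicit what the training loop computes, here is a minimal pure-Python sketch, assuming full-batch training as in the script: nn.Linear(1, 1) models y = w*x + b, nn.MSELoss() with its default mean reduction averages the squared residuals, and optim.SGD without momentum subtracts lr times the gradient. The helper names fit_gd and fit_closed_form are hypothetical, not from the original post. The closed-form ordinary least squares fit is computed alongside as a reference.

```python
# Hypothetical helpers (not from the original post)

def fit_gd(xs, ys, lr=1e-4, epochs=1000, w=0.0, b=0.0):
    """Full-batch gradient descent on the mean squared error of y = w*x + b."""
    n = len(xs)
    for _ in range(epochs):
        # Residuals r_i = (w*x_i + b) - y_i
        residuals = [w * x + b - y for x, y in zip(xs, ys)]
        # Gradients of L = (1/n) * sum(r_i^2) with respect to w and b
        grad_w = 2.0 / n * sum(r * x for r, x in zip(residuals, xs))
        grad_b = 2.0 / n * sum(residuals)
        # Plain gradient step, as optim.SGD does without momentum
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


def fit_closed_form(xs, ys):
    """Ordinary least squares: slope = cov(x, y) / var(x), intercept from the means."""
    n = len(xs)
    mx = sum(xs) / n
    my = sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    w = sxy / sxx
    return w, my - w * mx


# The data from data.csv
xs = [3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167,
      7.042, 10.791, 5.313, 7.997, 3.1]
ys = [1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221,
      2.827, 3.465, 1.65, 2.904, 1.3]

# lr and epochs here are illustrative choices that let gradient descent converge
w_gd, b_gd = fit_gd(xs, ys, lr=0.01, epochs=50000)
w_ls, b_ls = fit_closed_form(xs, ys)
print(f"gradient descent: w={w_gd:.4f}, b={b_gd:.4f}")
print(f"least squares:    w={w_ls:.4f}, b={b_ls:.4f}")
```

    The values lr=0.01 and 50000 epochs are assumptions picked so the two results agree. With the post's lr=1e-4 and 1000 epochs the parameters may stop short of the least-squares solution, and the closed form gives a quick check of how far the trained model is from the optimum.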
    

  • Original post: https://www.cnblogs.com/dudu1992/p/8980249.html