zoukankan      html  css  js  c++  java
  • 使用pytorch进行线性回归


    x,y
    3.3,1.7 4.4,2.76 5.5,2.09 6.71,3.19 6.93,1.694 4.168,1.573 9.779,3.366 6.182,2.596 7.59,2.53 2.167,1.221 7.042,2.827 10.791,3.465 5.313,1.65 7.997,2.904 3.1,1.3

    以上是欲拟合数据

    import torch
    from torch import nn, optim
    from torch.autograd import Variable
    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    
    d = pd.read_csv("data.csv")
    x_train = np.array(d.x[:],dtype=np.float32).reshape(15,1)
    
    print(x_train)
    y_train=np.array(d.y[:],dtype=np.float32).reshape(15,1)
    print(y_train)
    
    x_train = torch.from_numpy(x_train)
    
    y_train = torch.from_numpy(y_train)
    
    
    # Linear Regression Model
    class LinearRegression(nn.Module):
        def __init__(self):
            super(LinearRegression, self).__init__()
            self.linear = nn.Linear(1, 1)  # input and output is 1 dimension
    
        def forward(self, x):
            out = self.linear(x)
            return out
    
    
    model = LinearRegression()
    # 定义loss和优化函数
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4)
    
    # 开始训练
    num_epochs = 1000
    for epoch in range(num_epochs):
        inputs = Variable(x_train)
        target = Variable(y_train)
    
        # forward
        out = model(inputs)
        loss = criterion(out, target)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
        if (epoch+1) % 20 == 0:
            print('Epoch[{}/{}], loss: {:.6f}'
                  .format(epoch+1, num_epochs, loss.data[0]))
    
    model.eval()
    predict = model(Variable(x_train))
    predict = predict.data.numpy()
    plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')
    plt.plot(x_train.numpy(), predict, label='Fitting Line')
    # 显示图例
    plt.legend()
    plt.show()
    
    # 保存模型
    torch.save(model.state_dict(), './linear.pth')
    

      

  • 相关阅读:
    客户端session与服务端session
    对session和cookie的一些理解
    Servlet生命周期与工作原理
    Linux命令行编辑快捷键
    含有GROUP BY子句的查询中如何显示COUNT()为0的成果(分享)
    设计模式学习——准备(UML类图)
    find()方法
    js中的动态效果
    动态添加,移除,查找css属性的方法
    VUE中的require ( )
  • 原文地址:https://www.cnblogs.com/dudu1992/p/8980249.html
Copyright © 2011-2022 走看看