zoukankan      html  css  js  c++  java
  • 使用pytorch进行线性回归


    x,y
    3.3,1.7 4.4,2.76 5.5,2.09 6.71,3.19 6.93,1.694 4.168,1.573 9.779,3.366 6.182,2.596 7.59,2.53 2.167,1.221 7.042,2.827 10.791,3.465 5.313,1.65 7.997,2.904 3.1,1.3

    以上是欲拟合数据

    import torch
    from torch import nn, optim
    from torch.autograd import Variable
    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    
    d = pd.read_csv("data.csv")
    x_train = np.array(d.x[:],dtype=np.float32).reshape(15,1)
    
    print(x_train)
    y_train=np.array(d.y[:],dtype=np.float32).reshape(15,1)
    print(y_train)
    
    x_train = torch.from_numpy(x_train)
    
    y_train = torch.from_numpy(y_train)
    
    
    # Linear Regression Model
    class LinearRegression(nn.Module):
        def __init__(self):
            super(LinearRegression, self).__init__()
            self.linear = nn.Linear(1, 1)  # input and output is 1 dimension
    
        def forward(self, x):
            out = self.linear(x)
            return out
    
    
    model = LinearRegression()
    # 定义loss和优化函数
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4)
    
    # 开始训练
    num_epochs = 1000
    for epoch in range(num_epochs):
        inputs = Variable(x_train)
        target = Variable(y_train)
    
        # forward
        out = model(inputs)
        loss = criterion(out, target)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
        if (epoch+1) % 20 == 0:
            print('Epoch[{}/{}], loss: {:.6f}'
                  .format(epoch+1, num_epochs, loss.data[0]))
    
    model.eval()
    predict = model(Variable(x_train))
    predict = predict.data.numpy()
    plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')
    plt.plot(x_train.numpy(), predict, label='Fitting Line')
    # 显示图例
    plt.legend()
    plt.show()
    
    # 保存模型
    torch.save(model.state_dict(), './linear.pth')
    

      

  • 相关阅读:
    Kafk为什么这么快
    kafka消息格式演变
    kafka基础命令及api使用
    zookeeper && kafka && kafka manager开机自启动设置
    sqoop进行将Hive 词频统计的结果数据传输到Mysql中
    hive实例的使用
    使用HBase Shell命令
    Hadoop使用实例 词频统计和气象分析
    HDFS 操作命令
    第四次作业 描述HDFS体系结构、工作原理与流程
  • 原文地址:https://www.cnblogs.com/dudu1992/p/8980249.html
Copyright © 2011-2022 走看看