zoukankan      html  css  js  c++  java
  • 使用pytorch进行线性回归


    x,y
    3.3,1.7 4.4,2.76 5.5,2.09 6.71,3.19 6.93,1.694 4.168,1.573 9.779,3.366 6.182,2.596 7.59,2.53 2.167,1.221 7.042,2.827 10.791,3.465 5.313,1.65 7.997,2.904 3.1,1.3

    以上是欲拟合数据

    import torch
    from torch import nn, optim
    from torch.autograd import Variable
    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    
    d = pd.read_csv("data.csv")
    x_train = np.array(d.x[:],dtype=np.float32).reshape(15,1)
    
    print(x_train)
    y_train=np.array(d.y[:],dtype=np.float32).reshape(15,1)
    print(y_train)
    
    x_train = torch.from_numpy(x_train)
    
    y_train = torch.from_numpy(y_train)
    
    
    # Linear Regression Model
    class LinearRegression(nn.Module):
        def __init__(self):
            super(LinearRegression, self).__init__()
            self.linear = nn.Linear(1, 1)  # input and output is 1 dimension
    
        def forward(self, x):
            out = self.linear(x)
            return out
    
    
    model = LinearRegression()
    # 定义loss和优化函数
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4)
    
    # 开始训练
    num_epochs = 1000
    for epoch in range(num_epochs):
        inputs = Variable(x_train)
        target = Variable(y_train)
    
        # forward
        out = model(inputs)
        loss = criterion(out, target)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
        if (epoch+1) % 20 == 0:
            print('Epoch[{}/{}], loss: {:.6f}'
                  .format(epoch+1, num_epochs, loss.data[0]))
    
    model.eval()
    predict = model(Variable(x_train))
    predict = predict.data.numpy()
    plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')
    plt.plot(x_train.numpy(), predict, label='Fitting Line')
    # 显示图例
    plt.legend()
    plt.show()
    
    # 保存模型
    torch.save(model.state_dict(), './linear.pth')
    

      

  • 相关阅读:
    信息安全系统设计基础第九周学习总结
    信息安全程序设计基础第五周学习总结
    信息安全程序设计基础第二周学习总结
    信息安全程序设计基础第三周总结
    ubuntu 13.10安装jdk 1.7 owen
    vim的配置文件 owen
    程序的思想是相通的,语言只是一种手段 owen
    如何删除开机系统选择 owen
    easybcd添加或删除启动选项 owen
    星际译王词库 owen
  • 原文地址:https://www.cnblogs.com/dudu1992/p/8980249.html
Copyright © 2011-2022 走看看