zoukankan      html  css  js  c++  java
  • PyTorch练手项目一:训练一个简单的线性回归

    本文目的:展示如何利用PyTorch做一个简单的线性回归。

    1 随机生成一些数据

    #导入相关库
    import torch
    import torch.nn as nn
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    print("torch version: ", torch.__version__)  #torch version:  1.1.0
    
    
    #随机生成一些点,并做成DataFrame
    x = np.linspace(0, 5, 256)
    noise = np.random.randn(256) * 2
    y = x * 5 + 7 + noise
    df = pd.DataFrame()
    df['x'] = x
    df['y'] = y
    
    #可视化
    sns.lmplot(x='x', y='y', data=df, height=4)
    

    2 利用Pytorch进行线性回归

    三部曲:准备数据,准备模型,训练。

    #准备数据
    train_x = x.reshape(-1, 1).astype('float32')
    train_y = y.reshape(-1, 1).astype('float32')
    train_x = torch.from_numpy(train_x)
    train_y = torch.from_numpy(train_y)
    
    
    #准备模型
    model = nn.Linear(1, 1)
    
    
    #定义训练参数
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    epochs = 3000
    
    #开始训练
    for i in range(1, epochs+1):
        optimizer.zero_grad()
        out = model(train_x)
        loss = loss_fn(out, train_y)
        loss.backward()
        optimizer.step()
        if(i % 300 == 0):
            print('epoch {}  loss {:.4f}'.format(i, loss.item())) 
    

    3 结果可视化

    #获取参数值
    w, b = model.parameters()  #parameters()返回的是一个迭代器指向的对象
    print(w.item(), b.item())  
    
    #结果可视化
    #model返回的是总tensor,包含grad_fn,用data提取出的tensor是纯tensor
    pred = model.forward(train_x).data.numpy().squeeze() 
    plt.plot(x, y, 'go', label='Truth', alpha=0.3)
    plt.plot(x, pred, label='Predicted')
    plt.legend()
    plt.show()
    

    4 小结

    • 数据生成和可视化方法

    Reference

  • 相关阅读:
    dbcp2连接池获取数据库连接Connection
    ItelliJ基于Gradle创建及发布Web项目(三)
    freeswitch编译java esl
    Java程序(非web)slf4j整合Log4j2
    日期常用操作类DateUtil
    关于静态库
    Activity的setContentView的流程
    ProGuard详解
    remoteViews简介
    WMS—启动过程
  • 原文地址:https://www.cnblogs.com/inchbyinch/p/12120377.html
Copyright © 2011-2022 走看看