zoukankan      html  css  js  c++  java
  • 笔记1:入门实例

    pytorch实现线性回归

    导入相关python包

    import torch
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    from torch import nn
    %matplotlib inline
    

    加载数据

    data = pd.read_csv('E:/datasets/dataset/Income1.csv')
    X = torch.from_numpy(data.Education.values.reshape(-1, 1).astype(np.float32))
    Y = torch.from_numpy(data.Income.values.reshape(-1, 1).astype(np.float32))
    

    定义模型

    model = nn.Linear(in_features = 1, out_features = 1) # w * input + b 等价于 model(input)
    loss_func = nn.MSELoss() # 损失函数
    optimizer = torch.optim.SGD(params = model.parameters(), lr = 0.0001)
    

    训练模型

    for epoch in range(5000):
        for x, y in zip(X, Y):
            y_pred = model(x)             # 使用模型预测
            loss   = loss_func(y, y_pred) # 根据预测结果计算损失
            optimizer.zero_grad()         # 把变量梯度清 0
            loss.backward()               # 求解梯度
            optimizer.step()              # 优化模型参数
    

    查看训练结果

    model.weight, model.bias
    

    plt.scatter(data.Education, data.Income)
    plt.plot(X.numpy(), model(X).data.numpy(), c = 'r')
    

  • 相关阅读:
    Linux进阶之Linux中的标准输入输出
    PermCheck
    FrogRiverOne
    PermMissingElem
    FrogJmp
    TapeEquilibrium
    恒生电子长沙2016实习生笔试题
    接口和抽象类的异同点?
    C#实现二叉树
    C#实现栈和队列
  • 原文地址:https://www.cnblogs.com/miraclepbc/p/14329186.html
Copyright © 2011-2022 走看看