zoukankan      html  css  js  c++  java
  • 线性回归的简洁实现

    1 导入实验所需要的包

    import numpy as np
    import torch
    from torch import nn
    from torch.utils import data
    import matplotlib.pyplot as plt
    #解决内核挂掉
    import os
    os.environ["KMP_DUPLICATE_LIB_OK"]  =  "TRUE"

     

    2 生成数据集

      将根据带有噪声的线性模型构造一个人造数据集。任务是使用这个有限样本的数据集来恢复这个模型的参数。这里使用低维数据,这样可以很容易地将其可视化。

      在下面的代码中, 将生成一个包含 1000 个样本的数据集,每个样本包含从标准正态分布中采样的 2 个特征。合成数据集是一个矩阵 $mathbf{X} in mathbb{R}^{1000 imes 2} $ 。
      使用线性模型参数 $ mathbf{w}=[2,-3.4]^{ op}$ 、 $b=4.2$ 和噪声项 $epsilon$  生成数据集及其标签:

        $mathbf{y}=mathbf{X} mathbf{w}+b+epsilon$

      可以将 $epsilon$ 视为捕获特征和标签时的潜在观测误差。在这里认为标准假设成立,即 $epsilon $ 服从均值为 $0$ 的正态分布。为了简化问题, 我们将标准差设 为 $0.01 $ 。下面的代码生成合成数据集。

    def get_random_data(w,b,num_example):
        X = torch.normal(0,1,(num_example,len(w)))
        #X = torch.normal(0,1,(num_example,2))
        Y = torch.matmul(X,w)+b      #矩阵乘法,要求稍微低一点   
        Y += torch.normal(0,0.01,Y.shape)
        return X,Y.reshape(-1,1)
    
    true_w = torch.tensor([2,-3.4])
    true_b = 4.2
    features ,labels = get_random_data(true_w,true_b,1000)

     

    3 可视化初始数据

    plt.rcParams['figure.figsize']=(12,4)
    plt.subplot(1,2,1)
    plt.scatter(features[:,0],labels,s=2)
    plt.subplot(1,2,2)
    plt.scatter(features[:,1],labels,s=2)

     

    4 读取数据集

    def load_array(data_arrays, batch_size, is_train=True): 
        """构造一个PyTorch数据迭代器。"""
        dataset = data.TensorDataset(*data_arrays)
        return data.DataLoader(dataset, batch_size, shuffle=is_train)
    
    batch_size = 10
    data_iter = load_array((features, labels), batch_size)
    # next(iter(data_iter))

     

    5 定义模型

    #定义模型方法一:
    # net = nn.Sequential(nn.Linear(2, 1))
    # print(net[0])
    
    #定义模型方法二:
    # from collections import OrderedDict
    # net = nn.Sequential(OrderedDict([('mylinear',nn.Linear(2, 1))]))
    # print(net[0])
    
    #定义模型方法三:
    net = nn.Sequential()
    net.add_module('mylinear',nn.Linear(2, 1))
    print(net[0].weight)
    print(net[0].bias)
    print(net[0].weight.grad)
    
    print("net[0] = ",net[0])
    
    for index,param in enumerate(net.state_dict()):
        print("index = ",index)
        print("param = ",param)
        print("param_value = ",net.state_dict()[param])
        print('----------------')

    Parameter containing: tensor([[-0.6449, -0.3191]], requires_grad=True) Parameter containing: tensor([0.1766], requires_grad=True) None net[0] = Linear(in_features=2, out_features=1, bias=True) index = 0 param = mylinear.weight param_value = tensor([[-0.6449, -0.3191]]) ---------------- index = 1 param = mylinear.bias param_value = tensor([0.1766]) ----------------

    6 初始化模型参数

      正如我们在构造 nn.Linear 时指定输入和输出尺寸一样。

      现在我们直接访问参数以设定初始值。我们通过 net[0] 选择网络中的第一个图层,然后使用 weight.data 和 bias.data 方法访问参数。然后使用替换方法 normal_ 和 fill_ 来重写参数值。

    from torch.nn import init
    init.normal_(net[0].weight,mean=0,std=0.01)
    init.constant_(net[0].bias,val = 0)
    
    print(net[0].weight)
    print(net[0].bias)
    print(net[0].weight.grad)
    Parameter containing:
    tensor([[-0.0027, -0.0086]], requires_grad=True)
    Parameter containing:
    tensor([0.], requires_grad=True)
    None

    7 定义损失函数

    loss = nn.MSELoss()

     

    8 定义优化算法

    optimizer = torch.optim.SGD(net.parameters(), lr=0.03)

     

    9 训练

    num_epochs = 3
    for epoch in range(num_epochs):
        for X, y in data_iter:
            l = loss(net(X) ,y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
        l = loss(net(features), labels)
        print(f'epoch {epoch + 1}, loss {l:f}')
    epoch 1, loss 0.000228
    epoch 2, loss 0.000102
    epoch 3, loss 0.000103

    因上求缘,果上努力~~~~ 作者:希望每天涨粉,转载请注明原文链接:https://www.cnblogs.com/BlairGrowing/p/15428515.html

  • 相关阅读:
    [图解算法] 最短路径算法之 “Dijikstra”
    [前端随笔][CSS] 伪类的应用
    [前端随笔][JavaScript] 实现原生的事件监听<Vue原理>
    [图解算法] 最短路径算法之 “Floyd”
    [前端随笔][JavaScript][自制数据可视化] “中国地图”
    [前端随笔][JavaScript] 懒加载的实现(上划一次加载一部分)
    [前端随笔][CSS] 制作一个加载动画 即帖即用
    ThinkPHP下隐藏index.php以及URL伪静态
    PHP基础语法3
    PHP基础语法2
  • 原文地址:https://www.cnblogs.com/BlairGrowing/p/15428515.html
Copyright © 2011-2022 走看看