zoukankan      html  css  js  c++  java
  • pytorch-API实现线性回归

    示例:

    import torch
    import torch.nn as nn
    from torch import optim
    
    class MyModel(nn.Module):
    
        def __init__(self):
            super(MyModel,self).__init__()
            self.lr = nn.Linear(1,1)
    
    
        def forward(self,x):
            return self.lr(x)
    
    #准备数据
    
    x= torch.rand([500,1])
    y_true = 3*x+0.8
    #1.实例化模型
    model = MyModel()
    #2.实例化优化器
    optimizer = optim.Adam(model.parameters(),lr=0.1)
    #3.实例化损失函数
    loss_fn = nn.MSELoss()
    
    for i in range(500):
        #4.梯度置为0
        optimizer.zero_grad()
        #5.调用模型得到预测值
        y_predict = model(x)
        #6.通过损失函数,计算得到损失
        loss = loss_fn(y_predict,y_true)
        #7.反向传播,计算梯度
        loss.backward()
        #8.更新参数
        optimizer.step()
    
        #打印部分数据
        if i%10 ==0:
            print(i,loss.item())
    
    for param in model.parameters():
        print(param.item())
    

      

    使用英伟达显卡CUDA模式加速计算:

    import torch
    import torch.nn as nn
    from torch import optim
    import time
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


    class MyModel(nn.Module):

    def __init__(self):
    super(MyModel,self).__init__()
    self.lr = nn.Linear(1,1)


    def forward(self,x):
    return self.lr(x)

    #准备数据

    x= torch.rand([500,1]).to(device=device)
    y_true = 3*x+0.8
    #1.实例化模型
    model = MyModel().to(device)
    #2.实例化优化器
    optimizer = optim.Adam(model.parameters(),lr=0.1)
    #3.实例化损失函数
    loss_fn = nn.MSELoss()
    start = time.time()
    for i in range(500):
    #4.梯度置为0
    optimizer.zero_grad()
    #5.调用模型得到预测值
    y_predict = model(x)
    #6.通过损失函数,计算得到损失
    loss = loss_fn(y_predict,y_true)
    #7.反向传播,计算梯度
    loss.backward()
    #8.更新参数
    optimizer.step()

    #打印部分数据
    if i%10 ==0:
    print(i,loss.item())

    for param in model.parameters():
    print(param.item())

    end = time.time()

    print(end-start)

      

    多思考也是一种努力,做出正确的分析和选择,因为我们的时间和精力都有限,所以把时间花在更有价值的地方。
  • 相关阅读:
    css样式学习笔记
    Css教程玉女心经版本
    weblogic高级进阶之ssl配置证书
    weblogic高级进阶之查看日志
    weblogic之高级进阶JMS的应用
    【WebLogic使用】3.WebLogic配置jndi数据源
    shiro的helloworld
    尚硅谷spring 事物管理
    尚硅谷spring aop详解
    Spring Boot 2.x Redis多数据源配置(jedis,lettuce)
  • 原文地址:https://www.cnblogs.com/LiuXinyu12378/p/11379953.html
Copyright © 2011-2022 走看看