zoukankan      html  css  js  c++  java
  • pytorch Model Linear实现线性回归CUDA版本

    实验代码

    import torch
    import torch.nn as nn
    
    #y = wx + b
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel,self).__init__()
            #自定义代码
            # self.w = torch.rand([500,1],requires_grad=True)
            # self.b = torch.tensor(0,dtype=torch.float,requires_grad=True)
            # self.lr = nn.Linear(1,1)
            self.lr1 = nn.Linear(1,10)
            # self.lr2 = nn.Linear(10,20)
            # self.lr3 = nn.Linear(20,1)
    
    
        def forward(self,x):   #完成一次前项计算
            # y_predict = x*self.w + self.b
            # return y_predict
            # return self.lr(x)
            out1 = self.lr1(x)
            # out2 = self.lr2(out1)
            # out = self.lr3(out2)
            return out1
    
    
    
    if __name__ == '__main__':
        model = MyModel()
        # print(model.parameters())
        for i in model.parameters():
            print(i)
            print("*"*100)
        # y_predict = model(torch.FloatTensor([10]))
        # print(y_predict)
    

      Linear实现线性回归,cuda版本

    import torch
    import torch.nn as nn
    from torch import optim
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel,self).__init__()
            self.lr = nn.Linear(1,1)
    
        def forward(self,x):
            return self.lr(x)
    
    #准备数据  如果使用cuda,数据和模型需要to(device)
    x = torch.rand([500,1]).to(device)
    y_true = 3*x + 0.8
    #实例化模型
    model = MyModel().to(device)
    #实例化优化器
    optimizer = optim.Adam(model.parameters(),lr=0.1)
    #实例化损失函数
    loss_fn = nn.MSELoss()
    
    for i in range(500):
        #梯度置零
        optimizer.zero_grad()
        #调用模型得到预测值
        y_predict = model(x)
        #损失函数,计算损失
        loss = loss_fn(y_predict,y_true)
        #反向传播计算梯度
        loss.backward()
        #更新参数
        optimizer.step()
        #打印部分数据
        if i%10 ==0:
            print(i,loss.item())
    
    for param in model.parameters():
        print(param.item())
    

      

    多思考也是一种努力,做出正确的分析和选择,因为我们的时间和精力都有限,所以把时间花在更有价值的地方。
  • 相关阅读:
    ubuntu 14.04搭建PHP项目基本流程
    linux下 lvm 磁盘扩容
    LVM基本介绍与常用命令
    Linux LVM逻辑卷配置过程详解
    mysql 5.7中的用户权限分配相关解读!
    linux系统维护时的一些小技巧,包括系统挂载新磁盘的方法!可收藏!
    linux系统内存爆满的解决办法!~
    源、更新源时容易出现的问题解决方法
    NV显卡Ubuntu14.04更新软件导致登录死循环,不过可以进入tty模式
    一些要注意的地方
  • 原文地址:https://www.cnblogs.com/LiuXinyu12378/p/12300478.html
Copyright © 2011-2022 走看看