zoukankan      html  css  js  c++  java
  • pytorch深度学习:线性回归

    深度学习版线性回归,哈哈哈哈。

    在潘神的嘲笑下,我发了这篇博客,呜呜。

     1 import torch
     2 from torch import nn,optim
     3 
     4 class LR(nn.Module):
     5     def __init__(self):
     6         super(LR, self).__init__()
     7         self.linear=nn.Linear(1,1)
     8 
     9     def forward(self, x):
    10         out = self.linear(x)
    11         return out
    12 
    13     def train(self,inputs,target,criterion,optimizer,epoches):
    14         for epoch in range(epoches):
    15             output = model.forward(inputs)
    16             loss = criterion(output, target)
    17             optimizer.zero_grad()
    18             loss.backward()
    19             optimizer.step()
    20         return model, loss
    21 
    22 x=torch.Tensor([1,2,3,4,5,6])
    23 y=x+torch.rand(6)
    24 print(x)
    25 print(y)
    26 model=LR()
    27 print(list(model.parameters()))
    28 
    29 inputs=torch.unsqueeze(x,dim=1)
    30 target=torch.unsqueeze(y,dim=1)
    31 criterion=nn.MSELoss()
    32 optimizer = optim.SGD(model.parameters(), lr=1e-3)
    33 
    34 new_model,loss=model.train(inputs=inputs,target=target,criterion=criterion,optimizer=optimizer,epoches=10000)
    35 print(list(new_model.parameters()))
    36 print(loss.item())

     最后%一下来自华科的陈巨。

  • 相关阅读:
    周二
    周末
    简单I/O
    格式输出(1)
    c语言—变量
    水仙花数
    控制语句—循环语句
    mysql6数据库安装与配置
    如何解决Tomcat端口号被占用
    eclipse配置tomcat详细步骤
  • 原文地址:https://www.cnblogs.com/St-Lovaer/p/13678610.html
Copyright © 2011-2022 走看看