zoukankan      html  css  js  c++  java
  • Pytorch LinearModel

    import torch


    x_data = torch.Tensor([[1.0],[2.0],[3.0]])
    y_data = torch.Tensor([[2.0],[4.0],[6.0]])


    class LinearModel(torch.nn.Module):
    def __init__(self):
    super(LinearModel,self).__init__()
    self.linear = torch.nn.Linear(1,1)

    def forward(self,x):
    y_pred = self.linear(x)
    return y_pred


    model = LinearModel()
    print(model)
    params = list(model.parameters())
    print(params)
    criterion = torch.nn.MSELoss(size_average=False)
    #optimizer = torch.optim.SGD(model.parameters(),lr = 0.01)
    #optimizer = torch.optim.Adam(model.parameters(),lr=0.01)
    #optimizer = torch.optim.Adagrad(model.parameters(),lr= 0.01)
    #optimizer = torch.optim.ASGD(model.parameters(), lr=0.01)
    #optimizer = torch.optim.LBFGS(model.parameters(),lr = 0.01)
    #optimizer = torch.optim.RMSprop(model.parameters(),lr = 0.01)
    optimizer = torch.optim.Rprop(model.parameters(),lr = 0.01)
    for epoch in range(100):
    y_pred = model(x_data)
    loss = criterion(y_pred,y_data)
    print(epoch, loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step(closure=float)
    print('w = ',model.linear.weight.item())
    print('b = ',model.linear.bias.item())

    x_test = torch.Tensor([[4.0]])
    y_test = model(x_test)
    print('y_pred = ', y_test.data)

  • 相关阅读:
    IE设置cookie问题。
    正则表达式。
    Git和SVN区别
    点滴MarkDown~
    浏览器页面是否缩放问题。
    我理解的灰度发布。
    有衬线字体和无衬线字体
    移动开发规范
    Thunderbird 如何接收 Foxmail 发出的会议邀请。。
    移动端 CSS sprites 的 background-size 如何计算的。
  • 原文地址:https://www.cnblogs.com/songyuejie/p/14941599.html
Copyright © 2011-2022 走看看