zoukankan      html  css  js  c++  java
  • 1105pytorch实践

    pytorch实现单维度线性回归

    代码

    import torch
    
    # Toy single-feature regression data following y = 2x, shaped (3, 1).
    x_data = torch.tensor([[1.0], [2.0], [3.0]])
    y_data = torch.tensor([[2.0], [4.0], [6.0]])
    
    class LinearModel(torch.nn.Module):
        """Single-feature linear regression: prediction = w * x + b."""

        def __init__(self):
            super().__init__()
            # One input feature mapped to one output value.
            self.linear = torch.nn.Linear(1, 1)

        def forward(self, x):
            """Return the affine prediction for every row of ``x``."""
            return self.linear(x)
    
    model = LinearModel()

    # Sum (not mean) of squared errors over the batch. The old
    # `size_average=False` argument is deprecated; reduction='sum' is the
    # documented equivalent.
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(1000):
        y_pred = model(x_data)
        loss = criterion(y_pred, y_data)
        print(epoch, loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('w=', model.linear.weight.item())
    print('b=', model.linear.bias.item())

    # Inference only: no_grad avoids recording an autograd graph, so the
    # result can be printed directly instead of going through `.data`.
    with torch.no_grad():
        x_test = torch.tensor([[4.0]])
        y_test = model(x_test)
    print('y_pred=', y_test)

    结果

    Logistic 逻辑回归分类

    代码

    import torch
    import torch.nn.functional as F
    
    # Single-feature inputs, shaped (3, 1).
    x_data = torch.tensor([[1.0], [2.0], [3.0]])
    # Binary labels written as float literals so the tensor stays floating
    # point, matching the model's float output passed to BCELoss.
    y_data = torch.tensor([[0.0], [0.0], [1.0]])
    
    class LogisticRegressionModel(torch.nn.Module):
        """Logistic regression: a sigmoid applied to a single-feature linear score."""

        def __init__(self):
            super().__init__()
            # Maps one input feature to one raw score (logit).
            self.linear = torch.nn.Linear(1, 1)

        def forward(self, x):
            """Return sigmoid(w * x + b), a value in (0, 1), for every row of ``x``."""
            logits = self.linear(x)
            return logits.sigmoid()
    
    model = LogisticRegressionModel()

    # Binary cross-entropy summed (not averaged) over the batch.
    criterion = torch.nn.BCELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(1000):
        y_pred = model(x_data)
        loss = criterion(y_pred, y_data)
        print(epoch, loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('w=', model.linear.weight.item())
    print('b=', model.linear.bias.item())

    # Inference only: no_grad skips building autograd graphs, so the outputs
    # can be printed directly instead of going through `.data`. After the
    # loop, x_test / y_test hold the last query (2.5), as before.
    with torch.no_grad():
        for hours in (4.0, 2.5):
            x_test = torch.tensor([[hours]])
            y_test = model(x_test)
            print('y_pred=', y_test)
    
    import numpy as np
    import matplotlib.pyplot as plt
    
    # Evaluate the trained model on 200 evenly spaced points in [0, 10]
    # (the 'Hours' axis) and plot the predicted probability.
    x = np.linspace(0, 10, 200)
    # Reshape to (N, 1): one single-feature sample per row. -1 lets torch infer
    # N from the grid instead of repeating the literal 200.
    x_t = torch.tensor(x, dtype=torch.float32).view(-1, 1)
    with torch.no_grad():
        # No autograd graph is recorded, so .numpy() works without .data.
        y_t = model(x_t)
    y = y_t.numpy()
    plt.plot(x, y)
    # Red horizontal reference line at probability 0.5.
    plt.plot([0, 10], [0.5, 0.5], c='r')
    plt.xlabel('Hours')
    plt.ylabel('Probability of Pass')
    plt.grid()
    plt.show()

    结果

     多维度输入分类

    代码

    import torch
    import numpy as np
    
    # Load the comma-separated dataset as float32 so the tensors below are
    # float32 as well.
    xy = np.loadtxt('diabetes.csv.gz', dtype=np.float32, delimiter=',')
    # Every column except the last is an input feature (presumably 8 columns,
    # to match Model's first Linear layer -- verify against the data file).
    x_data = torch.from_numpy(xy[:, :-1])
    # Last column is the target. The list index [-1] keeps it 2-D as (N, 1)
    # and yields a contiguous copy rather than a view of xy.
    y_data = torch.from_numpy(xy[:, [-1]])
    
    class Model(torch.nn.Module):
        """Three-layer fully connected classifier (8 -> 6 -> 4 -> 1).

        A sigmoid follows every linear layer, including the last, so each
        output is a value in (0, 1).
        """

        def __init__(self):
            super().__init__()
            # Submodules are registered in the same order as before (linear1,
            # linear2, linear3, sigmoid, activate), so parameter order and
            # random-initialisation order are unchanged.
            self.linear1 = torch.nn.Linear(8, 6)
            self.linear2 = torch.nn.Linear(6, 4)
            self.linear3 = torch.nn.Linear(4, 1)
            self.sigmoid = torch.nn.Sigmoid()
            # NOTE(review): this ReLU is registered but forward() never uses
            # it; kept so the module's attributes stay the same.
            self.activate = torch.nn.ReLU()

        def forward(self, x):
            """Pass ``x`` through each linear layer, applying a sigmoid after each."""
            for layer in (self.linear1, self.linear2, self.linear3):
                x = self.sigmoid(layer(x))
            return x
    
    model = Model()

    criterion = torch.nn.BCELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Full-batch gradient descent: every epoch feeds the whole x_data at once.
    for epoch in range(100):
        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()

        # Forward pass and summed binary cross-entropy.
        y_pred = model(x_data)
        loss = criterion(y_pred, y_data)
        print(epoch, loss.item())

        # Backpropagate, then apply one SGD update.
        loss.backward()
        optimizer.step()

    结果

     总结

    相比 TensorFlow，PyTorch 的代码写起来更方便，模块化也做得很好。接下来继续学习。另外强烈推荐 B 站的 PyTorch 视频教程：https://www.bilibili.com/video/BV1Y7411d7Ys?p=8&spm_id_from=pageDriver

  • 相关阅读:
    MySQL索引背后的数据结构及算法原理 [转]
    5.5下对DDL操作提速的测试
    由浅入深理解索引的实现(2) [转]
    由浅入深理解索引的实现(1) [转]
    两个比较有用的字符串函数
    在慢查询里保留注释部分
    想在Innodb表上做OPTIMIZE操作?先等等看再说!
    Win CE和smartphone和pocket pc和windows mobile比较(zt)
    学习笔记(配置SQL Server 2005允许远程连接)
    配置程序集的版本策略(zt)
  • 原文地址:https://www.cnblogs.com/xiaofengzai/p/15515563.html
Copyright © 2011-2022 走看看