zoukankan      html  css  js  c++  java
  • 1105pytorch实践

    pytorch实现单维度线性回归

    代码

    import torch
    
    # Toy data set: three single-feature samples (as column vectors) and their targets.
    x_data = torch.tensor([[1.0], [2.0], [3.0]])
    y_data = torch.tensor([[2.0], [4.0], [6.0]])
    
    class LinearModel(torch.nn.Module):
        """Linear model with one input feature and one output: ``y = w * x + b``."""

        def __init__(self):
            super().__init__()
            # A single Linear layer holds the learnable weight and bias.
            self.linear = torch.nn.Linear(in_features=1, out_features=1)

        def forward(self, x):
            """Apply the linear layer to ``x`` and return the prediction."""
            return self.linear(x)
    
    model = LinearModel()

    # Sum (rather than average) the squared errors over the batch.
    # `reduction='sum'` replaces the deprecated `size_average=False`
    # and produces the same loss value.
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(1000):
        # Forward pass over the whole data set; report the current loss.
        y_pred = model(x_data)
        loss = criterion(y_pred, y_data)
        print(epoch, loss.item())

        # Clear gradients left from the previous step, backpropagate,
        # then let SGD update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Learned parameters of the single linear layer.
    print('w=', model.linear.weight.item())
    print('b=', model.linear.bias.item())

    # Prediction for an unseen input; no autograd graph is needed here.
    x_test = torch.Tensor([[4.0]])
    with torch.no_grad():
        y_test = model(x_test)
    print('y_pred=', y_test.data)

    结果

    Logistic 回归(逻辑回归)分类

    代码

    import torch
    import torch.nn.functional as F
    
    # Three single-feature samples with binary targets; only the last one is labelled 1.
    x_data = torch.tensor([[1.0], [2.0], [3.0]])
    y_data = torch.tensor([[0.0], [0.0], [1.0]])
    
    class LogisticRegressionModel(torch.nn.Module):
        """Logistic regression on one feature: ``sigmoid(w * x + b)``."""

        def __init__(self):
            super().__init__()
            # One input feature mapped to a single logit.
            self.linear = torch.nn.Linear(in_features=1, out_features=1)

        def forward(self, x):
            """Return the sigmoid of the linear layer's output for ``x``."""
            logits = self.linear(x)
            return torch.sigmoid(logits)
    
    model = LogisticRegressionModel()

    # Summed binary cross-entropy loss, minimised with plain SGD.
    criterion = torch.nn.BCELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(1000):
        # Start every iteration from clean gradients.
        optimizer.zero_grad()

        # Forward pass over the whole data set; report the current loss.
        y_pred = model(x_data)
        loss = criterion(y_pred, y_data)
        print(epoch, loss.item())

        # Backpropagate and apply one parameter update.
        loss.backward()
        optimizer.step()

    # Learned parameters of the linear layer.
    print('w=', model.linear.weight.item())
    print('b=', model.linear.bias.item())

    # Predicted probabilities for two inputs outside the training set.
    for value in (4.0, 2.5):
        x_test = torch.Tensor([[value]])
        y_test = model(x_test)
        print('y_pred=', y_test.data)
    
    import numpy as np
    import matplotlib.pyplot as plt
    
    # Evaluate the trained model on 200 evenly spaced inputs in [0, 10].
    x = np.linspace(0, 10, 200)
    # Reshape to a column vector; -1 lets the row count follow len(x).
    x_t = torch.Tensor(x).view(-1, 1)
    # Plotting only needs the values, so skip building an autograd graph.
    with torch.no_grad():
        y_t = model(x_t)
    y = y_t.numpy()

    # Model output curve plus a red horizontal reference line at 0.5.
    plt.plot(x, y)
    plt.plot([0, 10], [0.5, 0.5], c='r')
    plt.xlabel('Hours')
    plt.ylabel('Probability of Pass')
    plt.grid()
    plt.show()

    结果

     多维度输入分类

    代码

    import torch
    import numpy as np
    
    # Read the gzipped CSV as float32. Every column but the last goes into
    # x_data; the last column goes into y_data (the list index [-1] keeps it 2-D).
    xy = np.loadtxt('diabetes.csv.gz', dtype=np.float32, delimiter=',')
    x_data, y_data = map(torch.from_numpy, (xy[:, :-1], xy[:, [-1]]))
    
    class Model(torch.nn.Module):
        """Three-layer fully connected network (8 -> 6 -> 4 -> 1) with sigmoid activations."""

        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(8, 6)
            self.linear2 = torch.nn.Linear(6, 4)
            self.linear3 = torch.nn.Linear(4, 1)
            self.sigmoid = torch.nn.Sigmoid()
            # ReLU module kept as an attribute; forward() does not use it.
            self.activate = torch.nn.ReLU()

        def forward(self, x):
            """Pass ``x`` through the three linear layers, applying sigmoid after each."""
            for layer in (self.linear1, self.linear2, self.linear3):
                x = self.sigmoid(layer(x))
            return x
    
    model = Model()

    # Summed binary cross-entropy loss, minimised with plain SGD.
    criterion = torch.nn.BCELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(100):
        # Start every iteration from clean gradients.
        optimizer.zero_grad()

        # Forward pass over the full data set; report the current loss.
        y_pred = model(x_data)
        loss = criterion(y_pred, y_data)
        print(epoch, loss.item())

        # Backward pass followed by one parameter update.
        loss.backward()
        optimizer.step()

    结果

     总结

    相对于 TensorFlow,PyTorch 的代码写起来更加方便,模块化也做得很好,今后会继续学习。另外力荐 B 站上的 PyTorch 视频教程:https://www.bilibili.com/video/BV1Y7411d7Ys?p=8&spm_id_from=pageDriver

  • 相关阅读:
    mybatis plus使用redis作为二级缓存
    netty无缝切换rabbitmq、activemq、rocketmq实现聊天室单聊、群聊功能
    netty使用EmbeddedChannel对channel的出入站进行单元测试
    记jdk1.8中hashmap的tableSizeFor方法
    Cannot find class: BaseResultMap
    windows下远程访问Redis,windows Redis绑定ip无效,Redis设置密码无效,Windows Redis 配置不生效,Windows Redis requirepass不生效,windows下远程访问redis的配置
    学习记录
    eclipse的注释
    转:聊聊同步、异步、阻塞与非阻塞
    点滴笔记(二):利用JS对象把值传到后台
  • 原文地址:https://www.cnblogs.com/xiaofengzai/p/15515563.html
Copyright © 2011-2022 走看看