zoukankan      html  css  js  c++  java
  • [学习笔记] CNN与RNN方法结合

    CNN与RNN的结合

    问题

    前几天学习了RNN的推导以及代码,那么问题来了,能不能把CNN和RNN结合起来,我们通过CNN提取的特征,能不能也将其看成一个序列呢?答案是可以的。

    但是我觉得一般把直接提取的特征喂给RNN训练意义是不大的,因为RNN擅长处理的是不定长的序列,也就是说,seq size是不确定的,但是一般图像特征的神经元数量都是固定的,这个时候再接个RNN说实话意义不大,除非设计一种结构可以让网络输出不定长的序列。(我的一个简单想法就是再设计一条支路去学习一个神经元权重mask,按照规则过滤掉一些神经元,然后丢进RNN或者LSTM训练)

    如何实现呢

    import torch
    import torch.nn as nn
    from torchsummary import summary
    from torchvision import datasets,transforms
    import torch.optim as optim
    from tqdm import tqdm
    class Model(nn.Module):
        """CNN feature extractor followed by a 2-layer RNN, used as a 10-way classifier.

        Three stride-2 3x3 convolutions reduce a 1-channel image (28x28 MNIST in the
        script below) to a 128-channel 2x2 feature map. The 2*2 = 4 spatial
        positions are treated as a length-4 sequence of 128-dim vectors, run through
        the RNN, and the 4 RNN outputs are concatenated for the linear classifier.
        """

        def __init__(self):
            super(Model, self).__init__()

            self.feature_extractor = nn.Sequential(
                nn.Conv2d(1, 16, kernel_size=3, stride=2),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.Conv2d(16, 64, kernel_size=3, stride=2),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.Conv2d(64, 128, kernel_size=3, stride=2),
                nn.BatchNorm2d(128),
                nn.ReLU(),
            )
            # nn.RNN(input_size, hidden_size, num_layers)
            self.rnn = nn.RNN(128, 256, 2)
            # Initial hidden state template (num_layers, batch, hidden_size), kept as a
            # plain attribute for backward compatibility. forward() builds a zero
            # state sized to the actual batch and device instead, so batch sizes
            # other than 32 (e.g. the last partial batch, or eval batches) work.
            self.h0 = torch.zeros(2, 32, 256)
            # 4 time steps x 256 hidden units -> 10 class logits.
            self.predictor = nn.Linear(4 * 256, 10)

        def forward(self, x):
            """Return class logits of shape (batch, 10).

            Expects x as (batch, 1, H, W) with H, W such that the feature map is
            2x2 (true for 28x28 inputs), because the predictor is sized for 4 steps.
            """
            x = self.feature_extractor(x)  # (batch, 128, 2, 2)
            batch = x.size(0)
            # (batch, C, H, W) -> (H*W, batch, C): spatial positions become time steps.
            seq = x.flatten(2).permute(2, 0, 1).contiguous()
            h0 = x.new_zeros(self.rnn.num_layers, batch, self.rnn.hidden_size)
            out, _ = self.rnn(seq, h0)  # (H*W, batch, 256)
            # Move batch first before flattening so each row holds only one sample's
            # time steps; viewing (seq, batch, hidden) directly would mix samples.
            out = out.permute(1, 0, 2).contiguous().view(batch, -1)
            return self.predictor(out)
    
    if __name__ == "__main__":
        # Entry point: train Model on MNIST for 100 epochs, then save its weights.
        model = Model()
        loss_fn = nn.CrossEntropyLoss()
        train_dataset, test_dataset = (
            datasets.MNIST(root="./data/", train=is_train,
                           transform=transforms.ToTensor(), download=True)
            for is_train in (True, False)
        )

        make_loader = torch.utils.data.DataLoader
        train_loader = make_loader(dataset=train_dataset, batch_size=32, shuffle=True)
        # NOTE: test_loader is built but not used for evaluation in this script.
        test_loader = make_loader(dataset=test_dataset, batch_size=128, shuffle=False)

        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        print(len(train_loader))  # number of training batches per epoch
        num_epochs = 100
        for epoch_no in range(1, num_epochs + 1):
            total_loss = 0.0
            for images, labels in train_loader:
                logits = model(images)
                batch_loss = loss_fn(logits, labels)
                total_loss += batch_loss.item()
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
            # Reported value is the sum of per-batch mean losses over the epoch.
            print("epoch : {} and loss is : {}".format(epoch_no, total_loss))
        torch.save(model.state_dict(), "rnn_cnn.pth")
    

    上面代码可以看出我已经规定了RNN输入神经元的个数,所以肯定是定长的输入,我训练之后是可以收敛的。

    对于不定长,其实还是没办法改变每个batch的seq len,因为seq len一定是按最长的序列来规定的,所以没办法做到真正的不定长。所以我能做的就是通过一条支路学习一个权重,作用到原来的feature上去,起到类似mask的作用:代码中这个权重经过ReLU,是非负的,取值为0的位置相当于被过滤掉(如果想把它严格限制在0-1之间,可以改用sigmoid),这样其实就可以达到效果了。

    import torch
    import torch.nn as nn
    from torchsummary import summary
    from torchvision import datasets,transforms
    import torch.optim as optim
    import torch.nn.functional as F
    from tqdm import tqdm
    class Model(nn.Module):
        """CNN + RNN 10-way classifier with a learned spatial gate on the CNN features.

        Three stride-2 3x3 convolutions reduce a 1-channel image (28x28 MNIST in the
        script below) to a 128-channel 2x2 feature map. A 1x1 conv branch produces
        one score per spatial position, passed through ReLU and multiplied onto all
        channels at that position, so positions with a zero gate are masked out.
        The 4 gated positions form a length-4 sequence fed to a 2-layer RNN, whose
        concatenated outputs go to the linear classifier.
        """

        def __init__(self):
            super(Model, self).__init__()

            self.feature_extractor = nn.Sequential(
                nn.Conv2d(1, 16, kernel_size=3, stride=2),
                nn.BatchNorm2d(16),
                nn.ReLU6(),
                nn.Conv2d(16, 64, kernel_size=3, stride=2),
                nn.BatchNorm2d(64),
                nn.ReLU6(),
                nn.Conv2d(64, 128, kernel_size=3, stride=2),
                nn.BatchNorm2d(128),
                nn.ReLU6(),
            )
            # 1x1 conv: one gate score per spatial position, shared across channels.
            self.attn = nn.Conv2d(128, 1, kernel_size=1)
            # nn.RNN(input_size, hidden_size, num_layers)
            self.rnn = nn.RNN(128, 256, 2)
            # Initial hidden state template (num_layers, batch, hidden_size), kept as a
            # plain attribute for backward compatibility. forward() builds a zero
            # state sized to the actual batch and device instead.
            self.h0 = torch.zeros(2, 32, 256)
            # 4 time steps x 256 hidden units -> 10 class logits.
            self.predictor = nn.Linear(4 * 256, 10)

        def forward(self, x):
            """Return class logits of shape (batch, 10).

            Expects x as (batch, 1, H, W) with H, W such that the feature map is
            2x2 (true for 28x28 inputs), because the predictor is sized for 4 steps.
            """
            x = self.feature_extractor(x)  # (batch, 128, 2, 2)
            # ReLU makes the gate non-negative but does not cap it at 1; swap in
            # torch.sigmoid if a strict 0-1 weight is wanted.
            gate = F.relu(self.attn(x))  # (batch, 1, 2, 2), broadcast over channels
            x = x * gate
            batch = x.size(0)
            # (batch, C, H, W) -> (H*W, batch, C): spatial positions become time steps.
            seq = x.flatten(2).permute(2, 0, 1).contiguous()
            h0 = x.new_zeros(self.rnn.num_layers, batch, self.rnn.hidden_size)
            out, _ = self.rnn(seq, h0)  # (H*W, batch, 256)
            # Move batch first before flattening so each row holds only one sample's
            # time steps; viewing (seq, batch, hidden) directly would mix samples.
            out = out.permute(1, 0, 2).contiguous().view(batch, -1)
            return self.predictor(out)
    
    if __name__ == "__main__":
        # Entry point: train the gated CNN+RNN Model on MNIST, then save its weights.
        model = Model()
        loss_fn = nn.CrossEntropyLoss()
        train_dataset, test_dataset = (
            datasets.MNIST(root="./data/", train=is_train,
                           transform=transforms.ToTensor(), download=True)
            for is_train in (True, False)
        )

        make_loader = torch.utils.data.DataLoader
        train_loader = make_loader(dataset=train_dataset, batch_size=32, shuffle=True)
        # NOTE: test_loader is built but not used for evaluation in this script.
        test_loader = make_loader(dataset=test_dataset, batch_size=128, shuffle=False)

        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        print(len(train_loader))  # number of training batches per epoch
        num_epochs = 100
        for epoch_no in range(1, num_epochs + 1):
            total_loss = 0.0
            for images, labels in train_loader:
                logits = model(images)
                batch_loss = loss_fn(logits, labels)
                total_loss += batch_loss.item()
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
            # Reported value is the sum of per-batch mean losses over the epoch.
            print("epoch : {} and loss is : {}".format(epoch_no, total_loss))
        torch.save(model.state_dict(), "rnn_cnn.pth")
    

    我自己训练了一下,后者要比前者收敛得快得多。

  • 相关阅读:
    HTML-利用CSS和JavaScript制作一个切换图片的网页
    HTML-★★★格式与布局fixed/absolute/relative/z-index/float★★★
    HTML-CSS样式表-★★★常用属性★★★及基本概念、分类、选择器
    HTML-★★★★★表单★★★★★
    HTML-图片热点、网页内嵌、网页拼接、快速切图
    HTML-常用标签与表格标签
    HTML-基础及一般标签
    C#-★结构体★
    C#-函数的传值与传址
    C#-★★函数★★
  • 原文地址:https://www.cnblogs.com/aoru45/p/11576023.html
Copyright © 2011-2022 走看看