    Learning Memory-guided Normality: Code Study Notes

    The core of the memory module

    The heart of the memory module is the definition of the Memory class below.

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Memory(nn.Module):
        def __init__(self, memory_size, feature_dim, key_dim,  temp_update, temp_gather):
            super(Memory, self).__init__()
            # Constants
            self.memory_size = memory_size
            self.feature_dim = feature_dim
            self.key_dim = key_dim
            self.temp_update = temp_update
            self.temp_gather = temp_gather
            
        def hard_neg_mem(self, mem, i):
            # hard negative: the most similar memory item other than slot i
            # (note: self.keys_var is never set in this class; this helper is unused by forward)
            similarity = torch.matmul(mem, torch.t(self.keys_var))
            similarity[:, i] = -1
            _, max_idx = torch.topk(similarity, 1, dim=1)

            return self.keys_var[max_idx]
        
        def random_pick_memory(self, mem, max_indices):
            # for each memory slot, randomly pick one of the queries assigned to it
            # (-1 if no query matched that slot); also unused by forward
            m, d = mem.size()
            output = []
            for i in range(m):
                flattened_indices = (max_indices==i).nonzero()
                a, _ = flattened_indices.size()
                if a != 0:
                    number = np.random.choice(a, 1)
                    output.append(flattened_indices[number, 0])
                else:
                    output.append(-1)

            return torch.tensor(output)
        
        def get_update_query(self, mem, max_indices, update_indices, score, query, train):
            # build one update vector per memory slot: the sum of all queries whose
            # nearest memory item is that slot, weighted by the renormalized probabilities
            m, d = mem.size()
            # note: the train and test branches below are identical in the original code
            if train:
                query_update = torch.zeros((m,d)).cuda()
                for i in range(m):
                    idx = torch.nonzero(max_indices.squeeze(1)==i)
                    a, _ = idx.size()
                    if a != 0:
                        query_update[i] = torch.sum(((score[idx,i] / torch.max(score[:,i])) * query[idx].squeeze(1)), dim=0)
                    else:
                        query_update[i] = 0

                return query_update

            else:
                query_update = torch.zeros((m,d)).cuda()
                for i in range(m):
                    idx = torch.nonzero(max_indices.squeeze(1)==i)
                    a, _ = idx.size()
                    if a != 0:
                        query_update[i] = torch.sum(((score[idx,i] / torch.max(score[:,i])) * query[idx].squeeze(1)), dim=0)
                    else:
                        query_update[i] = 0

                return query_update
    
        def get_score(self, mem, query):
            # similarity between every query location and every memory slot
            bs, h,w,d = query.size()
            m, d = mem.size()

            score = torch.matmul(query, torch.t(mem))# b X h X w X m
            score = score.view(bs*h*w, m)# (b X h X w) X m

            score_query = F.softmax(score, dim=0)   # normalized over queries -> v_t^{k,m}
            score_memory = F.softmax(score,dim=1)   # normalized over memory  -> w_t^{k,m}

            return score_query, score_memory
        
        def forward(self, query, keys, train=True):
    
            batch_size, dims,h,w = query.size() # b X d X h X w
            query = F.normalize(query, dim=1)
            query = query.permute(0,2,3,1) # b X h X w X d
            
            #train
            if train:
                #losses
                separateness_loss, compactness_loss = self.gather_loss(query,keys, train)
                # read
                updated_query, softmax_score_query,softmax_score_memory = self.read(query, keys)
                #update
                updated_memory = self.update(query, keys, train)
                
                return updated_query, updated_memory, softmax_score_query, softmax_score_memory, separateness_loss, compactness_loss
            
            #test
            else:
                # loss
                compactness_loss, query_re, top1_keys, keys_ind = self.gather_loss(query,keys, train)
                
                # read
                updated_query, softmax_score_query,softmax_score_memory = self.read(query, keys)
                
                #update
                updated_memory = keys
                    
                   
                return updated_query, updated_memory, softmax_score_query, softmax_score_memory, query_re, top1_keys,keys_ind, compactness_loss
            
            
        
        def update(self, query, keys,train):
            
            batch_size, h,w,dims = query.size() # b X h X w X d 
            
            softmax_score_query, softmax_score_memory = self.get_score(keys, query)
            
            query_reshape = query.contiguous().view(batch_size*h*w, dims)
            
            _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1)
            _, updating_indices = torch.topk(softmax_score_query, 1, dim=0)
            
            if train:
                query_update = self.get_update_query(keys, gathering_indices, updating_indices, softmax_score_query, query_reshape,train)
                updated_memory = F.normalize(query_update + keys, dim=1)   # f = L2 normalization

            else:
                # the test-time branch applies the same update rule in this implementation
                query_update = self.get_update_query(keys, gathering_indices, updating_indices, softmax_score_query, query_reshape, train)
                updated_memory = F.normalize(query_update + keys, dim=1)
            
            return updated_memory.detach()
            
            
        def pointwise_gather_loss(self, query_reshape, keys, gathering_indices, train):
            # per-element compactness loss (unused by forward)
            n,dims = query_reshape.size() # (b X h X w) X d
            loss_mse = torch.nn.MSELoss(reduction='none')
            
            pointwise_loss = loss_mse(query_reshape, keys[gathering_indices].squeeze(1).detach())
                    
            return pointwise_loss
            
        def gather_loss(self,query, keys, train):
            batch_size, h,w,dims = query.size() # b X h X w X d
            if train:
                loss = torch.nn.TripletMarginLoss(margin=1.0)
                loss_mse = torch.nn.MSELoss()
                softmax_score_query, softmax_score_memory = self.get_score(keys, query)
            
                query_reshape = query.contiguous().view(batch_size*h*w, dims)
            
                _, gathering_indices = torch.topk(softmax_score_memory, 2, dim=1)
            
                #1st, 2nd closest memories
                pos = keys[gathering_indices[:,0]]
                neg = keys[gathering_indices[:,1]]
                top1_loss = loss_mse(query_reshape, pos.detach())
                gathering_loss = loss(query_reshape,pos.detach(), neg.detach())
                
                return gathering_loss, top1_loss
            
                
            else:
                loss_mse = torch.nn.MSELoss()
            
                softmax_score_query, softmax_score_memory = self.get_score(keys, query)
            
                query_reshape = query.contiguous().view(batch_size*h*w, dims)
            
                _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1)
            
                gathering_loss = loss_mse(query_reshape, keys[gathering_indices].squeeze(1).detach())
                
                return gathering_loss, query_reshape, keys[gathering_indices].squeeze(1).detach(), gathering_indices[:,0]
                
            
            
        
        def read(self, query, updated_memory):
            batch_size, h,w,dims = query.size() # b X h X w X d
    
            softmax_score_query, softmax_score_memory = self.get_score(updated_memory, query)
    
            query_reshape = query.contiguous().view(batch_size*h*w, dims)
            
            concat_memory = torch.matmul(softmax_score_memory.detach(), updated_memory) # (b X h X w) X d
            updated_query = torch.cat((query_reshape, concat_memory), dim = 1) # (b X h X w) X 2d
            updated_query = updated_query.view(batch_size, h, w, 2*dims)
            updated_query = updated_query.permute(0,3,1,2)
            
            return updated_query, softmax_score_query, softmax_score_memory
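
    With the class in place, here is a minimal smoke test I wrote (not from the repo): the sizes are made up, the tensors random, and a GPU is required because get_update_query hard-codes .cuda().

        # hypothetical usage sketch for the Memory module (sizes are assumptions)
        memory = Memory(memory_size=10, feature_dim=512, key_dim=512,
                        temp_update=0.1, temp_gather=0.1)
        query = torch.rand(4, 512, 32, 32).cuda()                  # b X d X h X w features
        m_items = F.normalize(torch.rand(10, 512).cuda(), dim=1)   # randomly initialized memory
        out = memory(query, m_items, train=True)
        updated_query, updated_memory = out[0], out[1]
        separateness_loss, compactness_loss = out[4], out[5]
        print(updated_query.shape)   # torch.Size([4, 1024, 32, 32]) -- query concat read memory
        print(updated_memory.shape)  # torch.Size([10, 512])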
        
    

    The update step

    The update calls get_update_query(self, mem, max_indices, update_indices, score, query, train) to compute \(\text{query\_update} = \sum_{k \in U_t^m} v_t'^{k,m} q_t^k\), where \(U_t^m\) is the set of queries whose nearest memory item is slot \(m\).

    It then updates the memory as \(f(p^m + \text{query\_update})\).

    The paper describes \(f\) as L2 normalization.
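    In the code, \(f\) is simply F.normalize. A quick check I wrote (using the imports above):

        # f is row-wise L2 normalization: update() applies F.normalize(query_update + keys, dim=1),
        # which rescales every memory row to unit length
        x = torch.randn(10, 512)
        fx = F.normalize(x, dim=1)
        print(torch.allclose(fx, x / x.norm(dim=1, keepdim=True)))  # True
        print(fx.norm(dim=1))                                       # all ones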

    Here is the definition of the get_update_query function:

        def get_update_query(self, mem, max_indices, update_indices, score, query, train):
            
            m, d = mem.size()
            if train:
                query_update = torch.zeros((m,d)).cuda()
                # random_update = torch.zeros((m,d)).cuda()
                for i in range(m):
                    idx = torch.nonzero(max_indices.squeeze(1)==i)
                    a, _ = idx.size()
                    if a != 0:
                        query_update[i] = torch.sum(((score[idx,i] / torch.max(score[:,i])) *query[idx].squeeze(1)), dim=0)
                    else:
                        query_update[i] = 0 
            
           
                return query_update 
        
            else:
                query_update = torch.zeros((m,d)).cuda()
                for i in range(m):
                    idx = torch.nonzero(max_indices.squeeze(1)==i)
                    a, _ = idx.size()
                    if a != 0:
                        query_update[i] = torch.sum(((score[idx,i] / torch.max(score[:,i])) *query[idx].squeeze(1)), dim=0)
                    else:
                        query_update[i] = 0 
                
                return query_update
    
    

    In this definition, the key quantity is \(v_t'^{k,m}\), which the code implements as score[idx,i] / torch.max(score[:,i]), i.e. each matching probability renormalized by the largest probability in that memory slot's column. That in turn depends on \(v_t^{k,m}\). Like \(w_t^{k,m}\), it is an attention weight, and both are computed by the get_score function, defined as follows:

        def get_score(self, mem, query):
            # compute the weights $w_t^{k,m}$ (score_memory) and $v_t^{k,m}$ (score_query)
            bs, h,w,d = query.size()
            m, d = mem.size()
            
            score = torch.matmul(query, torch.t(mem))# b X h X w X m
            score = score.view(bs*h*w, m)# (b X h X w) X m
            
            score_query = F.softmax(score, dim=0)
            score_memory = F.softmax(score,dim=1)
            
            return score_query, score_memory
    

    This implements the weight computation from the paper:

    \(w_t^{k,m} = \frac{\exp((p^m)^\top q_t^k)}{\sum_{m'=1}^{M} \exp((p^{m'})^\top q_t^k)}\), \(\quad v_t^{k,m} = \frac{\exp((p^m)^\top q_t^k)}{\sum_{k'=1}^{K} \exp((p^m)^\top q_t^{k'})}\)
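    A small sketch (assuming the imports above) confirms which axis each softmax normalizes over: score_query sums to one over the query locations for each memory slot, while score_memory sums to one over the memory slots for each query location.

        # check the normalization axes of get_score's two softmaxes
        score = torch.randn(6, 3)               # 6 query locations, 3 memory slots
        score_query = F.softmax(score, dim=0)   # v_t^{k,m}: normalized over queries
        score_memory = F.softmax(score, dim=1)  # w_t^{k,m}: normalized over memory items
        print(score_query.sum(dim=0))   # tensor([1., 1., 1.])
        print(score_memory.sum(dim=1))  # tensor([1., 1., 1., 1., 1., 1.])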

    The read step

    def read(self, query, updated_memory):
        # the read step
        batch_size, h,w,dims = query.size() # b X h X w X d

        softmax_score_query, softmax_score_memory = self.get_score(updated_memory, query)

        query_reshape = query.contiguous().view(batch_size*h*w, dims)

        # weighted average of the memory items, using the matching probabilities as weights
        concat_memory = torch.matmul(softmax_score_memory.detach(), updated_memory) # (b X h X w) X d
        # concatenate the query with the retrieved memory feature
        updated_query = torch.cat((query_reshape, concat_memory), dim = 1) # (b X h X w) X 2d
        updated_query = updated_query.view(batch_size, h, w, 2*dims)
        updated_query = updated_query.permute(0,3,1,2) # back to b X 2d X h X w

        return updated_query, softmax_score_query, softmax_score_memory
    

    The key steps are annotated in the comments above.
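    As a standalone shape check (my own sketch, reusing the imports above; assumed sizes b=4, h=w=32, d=512, M=10):

        # the read retrieves, for every query location, a convex combination of memory items
        mem = F.normalize(torch.rand(10, 512), dim=1)               # M X d memory bank
        score_memory = F.softmax(torch.randn(4*32*32, 10), dim=1)   # (b*h*w) X M weights
        concat_memory = torch.matmul(score_memory, mem)             # (b*h*w) X d retrieval
        print(concat_memory.shape)                                  # torch.Size([4096, 512])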

    The forward pass

    separateness_loss, compactness_loss = self.gather_loss(query,keys, train)
    # read
    updated_query, softmax_score_query,softmax_score_memory = self.read(query, keys)
    #update
    updated_memory = self.update(query, keys, train)
    
    return updated_query, updated_memory, softmax_score_query, softmax_score_memory, separateness_loss, compactness_loss
    

    In training mode, forward computes the losses with gather_loss, then calls the read and update functions in turn.

    The total training loss is \(L = L_{rec} + \lambda_c L_{compact} + \lambda_s L_{separate}\).
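    The reconstruction term \(L_{rec}\) comes from the encoder-decoder outside this class. As a hedged sketch of how the three terms might be combined in a training step (the weights and dummy tensors below are my assumptions, not the repo's exact code):

        # hypothetical assembly of L = L_rec + lambda_c * L_compact + lambda_s * L_separate
        lam_c, lam_s = 0.1, 0.1                                    # assumed loss weights
        outputs = torch.rand(4, 3, 256, 256, requires_grad=True)   # dummy reconstructed frames
        target = torch.rand(4, 3, 256, 256)                        # dummy ground-truth frames
        compactness_loss = torch.tensor(0.3, requires_grad=True)   # stand-ins for gather_loss outputs
        separateness_loss = torch.tensor(0.5, requires_grad=True)
        recon_loss = torch.mean((outputs - target) ** 2)           # L_rec
        total_loss = recon_loss + lam_c * compactness_loss + lam_s * separateness_loss
        total_loss.backward()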

    The compactness and separateness terms are implemented by the gather_loss function.

    def gather_loss(self,query, keys, train):
        batch_size, h,w,dims = query.size() # b X h X w X d
        if train:
        loss = torch.nn.TripletMarginLoss(margin=1.0)
        # triplet margin loss: used for the feature separateness loss
        loss_mse = torch.nn.MSELoss()
        # MSE loss: used for the feature compactness loss
            softmax_score_query, softmax_score_memory = self.get_score(keys, query)
        
            query_reshape = query.contiguous().view(batch_size*h*w, dims)
        
            _, gathering_indices = torch.topk(softmax_score_memory, 2, dim=1)
        
            #1st, 2nd closest memories
            pos = keys[gathering_indices[:,0]]
            neg = keys[gathering_indices[:,1]]
            top1_loss = loss_mse(query_reshape, pos.detach())
            gathering_loss = loss(query_reshape,pos.detach(), neg.detach())
            
            return gathering_loss, top1_loss
        
            
        else:
            loss_mse = torch.nn.MSELoss()
        
            softmax_score_query, softmax_score_memory = self.get_score(keys, query)
        
            query_reshape = query.contiguous().view(batch_size*h*w, dims)
        
            _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1)
        
            gathering_loss = loss_mse(query_reshape, keys[gathering_indices].squeeze(1).detach())
            
            return gathering_loss, query_reshape, keys[gathering_indices].squeeze(1).detach(), gathering_indices[:,0]
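
    For reference, the separateness term is a standard triplet margin loss over (query, 1st-closest, 2nd-closest) triplets. A quick equivalence check I wrote (not repo code):

        # verify what TripletMarginLoss(margin=1.0) computes
        q = torch.randn(8, 512)   # queries
        p = torch.randn(8, 512)   # positives: 1st-closest memory items
        n = torch.randn(8, 512)   # negatives: 2nd-closest memory items
        triplet = torch.nn.TripletMarginLoss(margin=1.0)(q, p, n)
        manual = torch.clamp((q - p).norm(dim=1) - (q - n).norm(dim=1) + 1.0, min=0).mean()
        print(torch.allclose(triplet, manual))  # True (up to a tiny eps in pairwise_distance)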
            
    
    Original post: https://www.cnblogs.com/pteromyini/p/14402629.html