zoukankan      html  css  js  c++  java
  • 注意力增强卷积 代码解读

    原论文 Attention Augmented Convolutional Networks

    代码来源 leaderj1001/Attention-Augmented-Conv2d

    导入模块&cuda加载

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    

    初始化&forward

    class AugmentedConv(nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size, dk, dv, Nh, shape=0, relative=False, stride=1):
            super(AugmentedConv, self).__init__()
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.kernel_size = kernel_size
            self.dk = dk
            self.dv = dv
            self.Nh = Nh
            self.shape = shape
            self.relative = relative  # 是否加入位置编码
            self.stride = stride
            self.padding = (self.kernel_size - 1) // 2
    
            assert self.Nh != 0, "integer division or modulo by zero, Nh >= 1"
            assert self.dk % self.Nh == 0, "dk should be divided by Nh. (example: out_channels: 20, dk: 40, Nh: 4)"
            assert self.dv % self.Nh == 0, "dv should be divided by Nh. (example: out_channels: 20, dv: 4, Nh: 4)"
            assert stride in [1, 2], str(stride) + " Up to 2 strides are allowed."
    
            # 这里要减去 dv,因为 conv_out 的输出要和 attn_out 的输出合并
            self.conv_out = nn.Conv2d(self.in_channels, self.out_channels - self.dv, self.kernel_size, stride=stride, padding=self.padding)
    
            # 这个卷积操作的目的就是得到 k, q, v, 注意卷积操作包含了计算 X * W_q, X * W_k, X * W_v 的过程
            self.qkv_conv = nn.Conv2d(self.in_channels, 2 * self.dk + self.dv, kernel_size=self.kernel_size, stride=stride, padding=self.padding)
    
            # attention 的结果仍要作为特征层传入卷积层进行特征提取
            self.attn_out = nn.Conv2d(self.dv, self.dv, kernel_size=1, stride=1)
    
            if self.relative:  # 每个位置的w, h相对位置编码的可学习参数量均为 2 * [w or h] - 1
                self.key_rel_w = nn.Parameter(torch.randn((int(2 * self.shape - 1), dk // Nh), requires_grad=True))
                self.key_rel_h = nn.Parameter(torch.randn((int(2 * self.shape - 1), dk // Nh), requires_grad=True))
    
        def forward(self, x):
            """
            attention augmented conv 的 “主函数”
            :param x: 输入数据,形状为 (batch_size, in_channels, height, width)
            :return: 最终输出,形状为 (batch, out_channels, height, width)
            """
    
            # conv_out -> (batch_size, out_channels - dv, height, width)
            conv_out = self.conv_out(x)
            batch, _, height, width = conv_out.size()
    
            # flat_q, flat_k, flat_v -> (batch_size, Nh, height * width, dvh or dkh)
            # dvh = dv / Nh, dkh = dk / Nh
            # q, k, v -> (batch_size, Nh, height, width, dv or dk)
            flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)
            logits = torch.matmul(flat_q.transpose(2, 3), flat_k)
            if self.relative:
                h_rel_logits, w_rel_logits = self.relative_logits(q)
                try:
                    logits += h_rel_logits
                except:
                    print(h_rel_logits.shape)
                logits += w_rel_logits
            weights = F.softmax(logits, dim=-1)
    
            # attn_out -> (batch_size, Nh, height * width, dvh)
            attn_out = torch.matmul(weights, flat_v.transpose(2, 3))
            attn_out = torch.reshape(attn_out, (batch, self.Nh, self.dv // self.Nh, height, width))
            # combine_heads_2d -> (batch_size, dv, height, width)
            attn_out = self.combine_heads_2d(attn_out)
            attn_out = self.attn_out(attn_out)  # 将注意力运算结果作为特征层传入卷积层
            return torch.cat((conv_out, attn_out), dim=1)
    
    

    功能函数

        def compute_flat_qkv(self, x, dk, dv, Nh):
            """
            计算 q, k, v 以及每个 head 的 q, k, v
            :param x: 输入数据,形状为 (batch_size, in_channels, height, width)
            :param dk: q, k 的维度
            :param dv: v 的维度
            :param Nh: 有多少个 head
            :return: flat_q, flat_k, flat_v, q, k, v
            """
            qkv = self.qkv_conv(x)  # 利用卷积操作求 q, k, v, 包含了计算 X * W_q, X * W_k, X * W_v 的过程
            N, _, H, W = qkv.size()
            q, k, v = torch.split(qkv, [dk, dk, dv], dim=1)  # 将卷积输出按 channel 划分为 q, k, v
            # 将single head 改为 multi-head
            q = self.split_heads_2d(q, Nh)
            k = self.split_heads_2d(k, Nh)
            v = self.split_heads_2d(v, Nh)
    
            dkh = dk // Nh
            q *= dkh ** -0.5
            # 得到每个 head 的 q, k, v
            flat_q = torch.reshape(q, (N, Nh, dk // Nh, H * W))
            flat_k = torch.reshape(k, (N, Nh, dk // Nh, H * W))
            flat_v = torch.reshape(v, (N, Nh, dv // Nh, H * W))
            return flat_q, flat_k, flat_v, q, k, v
    
        def split_heads_2d(self, x, Nh):
            """
            划分 head
            :param x: q or k or v
            :param Nh: head 的数量,必须要能整除 q, k, v 的 channel 维度数
            :return: reshape 后的 q, k, v
            """
            batch, channels, height, width = x.size()
            ret_shape = (batch, Nh, channels // Nh, height, width)
            split = torch.reshape(x, ret_shape)
            return split
    
        def combine_heads_2d(self, x):
            """
            将所有 head 的输出组合到一起
            :param x: 包含所有 head 的输出
            :return: 组合后的输出
            """
            batch, Nh, dv, H, W = x.size()
            ret_shape = (batch, Nh * dv, H, W)
            return torch.reshape(x, ret_shape)
    

    位置编码

        def relative_logits(self, q):
            """
            计算相对位置编码
            :param q: q
            :return: h 和 w 的位置编码
            """
            B, Nh, dk, H, W = q.size()
            # q -> (B, Nh, H, W, dk)
            q = torch.transpose(q, 2, 4).transpose(2, 3)
            # 分别计算 w 与 h 的一维编码
            rel_logits_w = self.relative_logits_1d(q, self.key_rel_w, H, W, Nh, "w")
            rel_logits_h = self.relative_logits_1d(torch.transpose(q, 2, 3), self.key_rel_h, W, H, Nh, "h")
    
            return rel_logits_h, rel_logits_w
    
        def relative_logits_1d(self, q, rel_k, H, W, Nh, case):
            """
            计算一维位置编码
            :param q: q,维度为(B, Nh, H, W, dk)
            :param rel_k: 位置编码的可学习参数,形状为为 (2 * [w or h] - 1, dk // Nh)
            :param H: 输入特征高度
            :param W: 输入特征宽度
            :param Nh: head 数量
            :param case: 区分 w 还是 h 的位置编码
            :return: 位置编码,形状为 (B, Nh, H * W, H * W)
            """
            # 使用爱因斯坦求和约定,实现批量矩阵乘法
            rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k)
            # 因为是一维位置编码 (w or h),所以另一个维度用不上
            rel_logits = torch.reshape(rel_logits, (-1, Nh * H, W, 2 * W - 1))
            # 加入位置信息
            rel_logits = self.rel_to_abs(rel_logits)
            # 下面的操作都是为了最后能产生形状为 (B, Nh, H * W, H * W) 的输出,以便于与 logit 相加
            # 详见 forward 函数第 18 行
            rel_logits = torch.reshape(rel_logits, (-1, Nh, H, W, W))
            rel_logits = torch.unsqueeze(rel_logits, dim=3)
            rel_logits = rel_logits.repeat((1, 1, 1, H, 1, 1))
    
            if case == "w":
                rel_logits = torch.transpose(rel_logits, 3, 4)
            elif case == "h":
                rel_logits = torch.transpose(rel_logits, 2, 4).transpose(4, 5).transpose(3, 5)
            # 改变形状以便于与 logit 相加
            rel_logits = torch.reshape(rel_logits, (-1, Nh, H * W, H * W))
            return rel_logits
    
        def rel_to_abs(self, x):
            """
            相对 to 绝对,在位置编码中加入绝对位置信息
            :param x: 原始位置编码,形状为 (B, Nh * H, W, 2 * W - 1)
            :return: 位置编码,形状为 (B, Nh * H, W, W)
            """
            B, Nh, L, _ = x.size()
            # '0' 即绝对位置信息,此后所有操作都是为了让同一 [行 or 列] 的每个点的位置编码的 '0' 出现的位置不同
            # 在最后一个维度的末尾,即每隔 2L - 1 的位置加入 0,
            # 这就是为什么 key_rel_[w or h],即可学习参数有 2 * [w or h] - 1 个
            col_pad = torch.zeros((B, Nh, L, 1)).to(x)
            x = torch.cat((x, col_pad), dim=3)
    
            # 每个 head 加入 L - 1 个 0, 为了让每一 [行 or 列] 的 '0' 错位
            flat_x = torch.reshape(x, (B, Nh, L * 2 * L))
            flat_pad = torch.zeros((B, Nh, L - 1)).to(x)
            flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)
            # 将 (L * 2 * L) + (L - 1) 个编码元素重新组织,使其形状为为 (L + 1, 2 * L - 1)
            # 目的是让 '0' 错位,这样每一 [行 or 列] 的点的位置编码中 '0' 出现的位置不一样
            # 相当于嵌入了绝对位置信息
            final_x = torch.reshape(flat_x_padded, (B, Nh, L + 1, 2 * L - 1))
            # reshape 以便于后续操作
            final_x = final_x[:, :, :L, L - 1:]
            return final_x
    

    使用示例

    if __name__ == "__main__":
        tmp = torch.randn((4, 3, 32, 32)).to(device)
        augmented_conv1 = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=4,
                                        relative=True, stride=2, shape=16).to(device)
        conv_out1 = augmented_conv1(tmp)
        print(conv_out1.shape)
    
        for name, param in augmented_conv1.named_parameters():
            print('parameter name: ', name)
    
        augmented_conv2 = AugmentedConv(in_channels=3, out_channels=20, kernel_size=3, dk=40, dv=4, Nh=4,
                                        relative=True, stride=1, shape=32).to(device)
        conv_out2 = augmented_conv2(tmp)
        print(conv_out2.shape)
    
  • 相关阅读:
    ORACLE AWR 和 ASH
    11g RAC R2 日常巡检--Grid
    Linux中重命名文件
    Xshell4连接Linux后 win快捷键锁屏
    vim 删除临时文件
    shell--read命令
    shell基础篇(一)从hello world开始
    ORACLE--分区表数据清理
    Shell—学习之心得
    awk 手册--【转载】
  • 原文地址:https://www.cnblogs.com/wang-haoran/p/14135944.html
Copyright © 2011-2022 走看看