【Module Stitching】【2022 TPAMI】External Attention: a dictionary-like attention mechanism


Contents

    • Introduction:
      • External-attention structure diagram
      • Usage: roughly at the last layer
      • Code:
        • Official code:
        • Code from an unofficial repository that collects attention modules:

          paper: https://arxiv.org/pdf/2105.02358v2

          code: https://paperswithcode.com/paper/beyond-self-attention-external-attention


          Introduction:

          Abstract:

          Attention mechanisms, especially self-attention, play an increasingly important role in deep feature representation for visual tasks. Self-attention updates the feature at each position by computing a weighted sum of features using pairwise affinities across all positions, so as to capture long-range dependencies within a single sample. However, self-attention has quadratic complexity and ignores potential correlations between different samples. This paper proposes a novel attention mechanism, called external attention, based on two external, small, learnable, shared memories; it can be implemented easily with just two cascaded linear layers and two normalization layers, and it conveniently replaces self-attention in existing popular architectures. External attention has linear complexity and implicitly considers the correlations between all data samples. The authors further incorporate the multi-head mechanism into external attention, yielding an all-MLP architecture for image classification, the external attention MLP (EAMLP). Extensive experiments on image classification, object detection, semantic segmentation, instance segmentation, image generation, and point cloud analysis show that the method provides results comparable to or better than the self-attention mechanism and some of its variants, at much lower computational and memory cost.
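
          The "two cascaded linear layers and two normalization layers" can be written compactly. Notation follows the paper; Norm is the double normalization used in the code below, a softmax over the N positions followed by an L1 normalization over the S memory entries:

          $A = \mathrm{Norm}\left(F M_k^{\top}\right) \in \mathbb{R}^{N \times S}, \qquad F_{\mathrm{out}} = A M_v \in \mathbb{R}^{N \times d}, \qquad F \in \mathbb{R}^{N \times d},\; M_k, M_v \in \mathbb{R}^{S \times d}$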

          Conclusion:

          The paper introduces external attention, a novel yet effective attention mechanism that can be used in a variety of visual tasks. The two external memory units adopted in external attention can be viewed as a dictionary for the whole dataset, and can learn more representative input features while reducing computational cost. The authors hope external attention will inspire both practical applications and research into its use in other fields such as NLP.


          External-attention structure diagram

          (Figure: external-attention structure diagram)

          The computational complexity of external attention is O(dSN); as d and S are hyper-parameters, the proposed algorithm is linear in the number of pixels. In fact, we find that a small S, e.g. 64, works well in experiments. Thus, external attention is much more efficient than self-attention, allowing its direct application to large-scale inputs.
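
          To see why the cost is linear in N, here is a shape-only sketch of the two matrix products external attention performs; the values of N, d and S below are arbitrary examples, not taken from the paper:

          # shape-only sketch: both matmuls cost O(N*d*S), i.e. linear in the number of pixels N
          import torch

          N, d, S = 64 * 64, 512, 64        # pixels, feature dimension, memory size (example values)
          F_in = torch.randn(N, d)          # input features
          M_k = torch.randn(S, d)           # external key memory
          M_v = torch.randn(S, d)           # external value memory

          A = F_in @ M_k.t()                # (N, d) x (d, S) -> (N, S)
          A = A.softmax(dim=0)              # softmax over the N positions
          A = A / (1e-9 + A.sum(dim=1, keepdim=True))  # l1-normalize over the S memory entries
          F_out = A @ M_v                   # (N, S) x (S, d) -> (N, d)
          print(F_out.shape)                # torch.Size([4096, 512])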

          Usage: roughly at the last layer


          Code:


          Official code:

          # from: https://github.com/MenghaoGuo/EANet/blob/main/model_torch.py
          # (imports added here so the snippet is self-contained; the original file
          #  configures norm_layer elsewhere -- plain BatchNorm2d is assumed below)
          import math
          import torch
          from torch import nn
          from torch.nn import functional as F
          from torch.nn.modules.batchnorm import _BatchNorm

          norm_layer = nn.BatchNorm2d  # assumption: the repo allows swapping in another norm layer


          class External_attention(nn.Module):
              '''
              Arguments:
                  c (int): The input and output channel number.
              '''
              def __init__(self, c):
                  super(External_attention, self).__init__()

                  self.conv1 = nn.Conv2d(c, c, 1)

                  self.k = 64  # number of memory units (S in the paper)
                  self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False)  # M_k: c -> k
                  self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False)  # M_v: k -> c
                  # tie M_v to the transpose of M_k at initialization
                  self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2)

                  self.conv2 = nn.Sequential(
                      nn.Conv2d(c, c, 1, bias=False),
                      norm_layer(c))

                  for m in self.modules():
                      if isinstance(m, nn.Conv2d):
                          n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                          m.weight.data.normal_(0, math.sqrt(2. / n))
                      elif isinstance(m, nn.Conv1d):
                          n = m.kernel_size[0] * m.out_channels
                          m.weight.data.normal_(0, math.sqrt(2. / n))
                      elif isinstance(m, _BatchNorm):
                          m.weight.data.fill_(1)
                          if m.bias is not None:
                              m.bias.data.zero_()

              def forward(self, x):
                  idn = x
                  x = self.conv1(x)

                  b, c, h, w = x.size()
                  n = h * w
                  x = x.view(b, c, n)             # b, c, n

                  attn = self.linear_0(x)         # b, k, n
                  attn = F.softmax(attn, dim=-1)  # softmax over the n positions
                  attn = attn / (1e-9 + attn.sum(dim=1, keepdim=True))  # l1-normalize over k
                  x = self.linear_1(attn)         # b, c, n

                  x = x.view(b, c, h, w)
                  x = self.conv2(x)
                  x = x + idn                     # residual connection
                  x = F.relu(x)
                  return x
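
          A minimal smoke test for the module above (the feature-map size is an arbitrary example, and norm_layer is assumed to be plain BatchNorm2d as in the snippet):

          # hypothetical usage of External_attention on a 4D feature map
          x = torch.randn(2, 512, 32, 32)   # b, c, h, w
          ea = External_attention(c=512)
          y = ea(x)
          print(y.shape)                    # torch.Size([2, 512, 32, 32]) -- same shape as the input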
          

          Implementing multi-head attention:

          (Figures: multi-head external attention)

          Official code:

          # from: https://github.com/MenghaoGuo/EANet/blob/main/multi_head_attention_torch.py
          # (imports added so the snippet is self-contained)
          import torch
          from torch import nn


          class Attention(nn.Module):
              def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
                  # qkv_bias / qk_scale are not used by external attention; they are
                  # accepted so the class keeps the usual self-attention signature
                  super().__init__()
                  self.num_heads = num_heads
                  assert dim % num_heads == 0
                  self.coef = 4
                  self.trans_dims = nn.Linear(dim, dim * self.coef)  # expand channels by coef
                  self.num_heads = self.num_heads * self.coef
                  self.k = 256 // self.coef                          # memory size per head
                  self.linear_0 = nn.Linear(dim * self.coef // self.num_heads, self.k)  # M_k
                  self.linear_1 = nn.Linear(self.k, dim * self.coef // self.num_heads)  # M_v

                  self.attn_drop = nn.Dropout(attn_drop)
                  self.proj = nn.Linear(dim * self.coef, dim)
                  self.proj_drop = nn.Dropout(proj_drop)

              def forward(self, x):
                  B, N, C = x.shape
                  x = self.trans_dims(x)                                     # B, N, C*coef
                  x = x.view(B, N, self.num_heads, -1).permute(0, 2, 1, 3)   # B, heads, N, head_dim

                  attn = self.linear_0(x)                                    # B, heads, N, k
                  attn = attn.softmax(dim=-2)                                # softmax over the N tokens
                  attn = attn / (1e-9 + attn.sum(dim=-1, keepdim=True))      # l1-normalize over k
                  attn = self.attn_drop(attn)
                  x = self.linear_1(attn).permute(0, 2, 1, 3).reshape(B, N, -1)

                  x = self.proj(x)
                  x = self.proj_drop(x)
                  return x
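
          A minimal smoke test for the multi-head version above (the token shape mimics a 14x14 patch grid; dim=384 and num_heads=8 are arbitrary example values):

          # hypothetical usage of the multi-head external attention block on a token sequence
          tokens = torch.randn(2, 196, 384)   # B, N, C
          mha = Attention(dim=384, num_heads=8)
          out = mha(tokens)
          print(out.shape)                    # torch.Size([2, 196, 384])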
          

          Code from an unofficial repository that collects attention modules:

          Compared with the official implementation above, this version works directly on token sequences of shape (bs, n, d_model) and keeps only the two memory layers plus the double normalization, dropping the 1x1 convolutions, the residual connection, and the ReLU.

          # from: https://github.com/xmu-xiaoma666/External-Attention-pytorch/blob/master/model/attention/ExternalAttention.py
          import torch
          from torch import nn
          from torch.nn import init


          class ExternalAttention(nn.Module):
              def __init__(self, d_model, S=64):
                  super().__init__()
                  self.mk = nn.Linear(d_model, S, bias=False)   # external key memory M_k
                  self.mv = nn.Linear(S, d_model, bias=False)   # external value memory M_v
                  self.softmax = nn.Softmax(dim=1)              # softmax over the token dimension
                  self.init_weights()

              def init_weights(self):
                  for m in self.modules():
                      if isinstance(m, nn.Conv2d):
                          init.kaiming_normal_(m.weight, mode='fan_out')
                          if m.bias is not None:
                              init.constant_(m.bias, 0)
                      elif isinstance(m, nn.BatchNorm2d):
                          init.constant_(m.weight, 1)
                          init.constant_(m.bias, 0)
                      elif isinstance(m, nn.Linear):
                          init.normal_(m.weight, std=0.001)
                          if m.bias is not None:
                              init.constant_(m.bias, 0)

              def forward(self, queries):
                  attn = self.mk(queries)                               # bs, n, S
                  attn = self.softmax(attn)                             # softmax over n
                  attn = attn / torch.sum(attn, dim=2, keepdim=True)    # l1-normalize over S
                  out = self.mv(attn)                                   # bs, n, d_model
                  return out


          if __name__ == '__main__':
              input = torch.randn(50, 49, 512)
              ea = ExternalAttention(d_model=512, S=8)
              output = ea(input)
              print(output.shape)
              
          