整体流程:在给定一张输入图像后,1)特征向量提取:首先经过ResNet提取图像的最后一层特征图F。注意此处仅仅用了一层特征图,是因为后续计算复杂度原因,另外,由于仅用最后一层特征图,故对小目标检测不友好,这也是后续deformable detr改进的原因。 2)添加位置编码信息:经F拉平成一维张量并添加上位置编码信息得到I。3)Transformer中encoder部分4)Transformer中decoder部分,学习位置嵌入object queries。5)FFN部分:6)后续匈牙利匹配+损失计算。
Detr的内部逻辑如下:在mmdet/models/detector/single_stage.py。即首先提取图像特征向量,之后经过DetrHead来计算最终的损失。
mmdet中提取图像特征向量的config配置文件如下,可以发现用ResNet50并只提取了最后一层特征层,即out_indices=(3,)。关于内部原理参见我的博文:mmdet之backbone介绍。
本部分代码来自mmdet/models/dense_heads/detr_head.py。
mmdet中生成位置编码信息借助的是mask矩阵,所谓的mask就是为了统一批次大小而对图像进行了pad,被填充的部分在后续计算多头注意力时应该舍弃,故需要一个mask矩阵遮挡住,具体形状为[batch, h,w]这里先贴下生成mask的过程:
我这里简单贴下mask示意图:
在有了mask基础上[batch,256,h,w],注意此时的hw是原图大小的;而输入图像的经过resnet50下采样后hw已经变了,所以还需进一步将mask下采样成和图像特征向量一样的shape。代码如下:
后续便可以生成位置编码部分(mmdet/models/utils/position_encoding.py),该函数给masks的每个像素位置生成了一个256维的唯一的位置向量。我这简单写了个测试脚本:
感兴趣可以看下mmdet关于位置编码这部分实现逻辑(只是做了简单注释):
在得到图像特征向量x=[b,c,h,w]、masks[b,h,w]矩阵以及位置编码pos_embed[b,256,h,w]后,便可送入Transformer,关键是厘清encoder和decoder的QKV分别指啥,看代码:
其中encoder中q就是x,kv分别为None,query_pos代表位置编码,而query_key_padding_mask就是mask。decoder的q是全0的target,后续decoder会迭代更新q,而kv则 是memory,即encoder的输出;key_pos依旧是k的位置信息;query_embed即论文中Object query,可学习位置信息;key_padding_mask依然是mask。
先看下encoder初始化部分,内部循环调用了6次BaseTransformerLayer,因此只需讲解一层EncoderLayer即可。
在来看下BaseTransformerLayer的forward部分,该部分可以损失detr的核心部分了,因为本质上mmdet内部只是封装了pytorch现有的nn.MultiHeadAtten函数。所以,需要理解nn.MultiHeadAttn中两种mask参数的含义,限于篇幅原因,这里可参考nn.Transformer来理解这两个mask。 不过简单理解就是:attn_mask在detr中没用到,仅用key_padding_mask。attn_mask是为了遮挡未来文本信息用的,而图像可以看到全部的信息,因此不需要用attn_mask。
decoder部分和encoder流程类似,只是多了交叉注意力。
我这里简单贴下nn.MultiHeadAttn内部流程:
上述代码流程比较简单,就是首先计算Q中每个元素和K的相似度,要依次用两种mask来遮挡住,为后续的softmax做准备。以cross attn为例,attn_output_weights是计算了每个真实单词和原始句子每个单词的相似性权重,所以要用和src_key_padding_mask一样的memory_key_padding_mask在行的维度上进行遮挡,故二者pad_mask是一致的。
由于后续在detr上改进的论文对匈牙利算法以及loss计算改动不大,因此这部分代码就不讲解了。 感觉写的已经够乱了,哭脸。
到此这篇detr源码(detectron2源码解读)的文章就介绍到这了,更多相关内容请继续浏览下面的相关推荐文章,希望大家都能在编程的领域有一番成就!版权声明:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若内容造成侵权、违法违规、事实不符,请将相关资料发送至xkadmin@xkablog.com进行投诉反馈,一经查实,立即处理!
转载请注明出处,原文链接:https://www.xkablog.com/rfx/20830.html