当前位置: 首页>后端>正文

Loss Function可视化 transformer可视化

最近需要对Transformer网络的中间层进行可视化,便于分析网络,在此记录一些常用到的概念。

 

常用到的方法主要是Attention Rollout和Attention Flow

Loss Function可视化 transformer可视化,Loss Function可视化 transformer可视化_可视化,第1张

,这两种方法都对网络中每一层的token attentions进行递归计算,主要的不同在于假设低层的attention weights如何影响到高层的信息流,以及是否计算token attentions之间的相关性。

 

 为了计算信息从输入层传播到高层的嵌入方式,关键是考虑模型中残差连接以及attention权重。在一个Transformer块中,self-attention和前向网络都被残差连接包裹,也就是将这些模块的输入添加到输出中。当仅使用attention weights来近似Transformers中的信息流时,就忽略了残差连接。但是残差连接连接了不同层的对应位置,所以在计算attention rollout和attention flow时,用额外的权重来表示残差连接。给定一个有残差连接的attention模块,将层

Loss Function可视化 transformer可视化,Loss Function可视化 transformer可视化_Loss Function可视化_02,第2张

的值计算为

Loss Function可视化 transformer可视化,Loss Function可视化 transformer可视化_Loss Function可视化_03,第3张

,其中

Loss Function可视化 transformer可视化,Loss Function可视化 transformer可视化_可视化_04,第4张

是attention矩阵,因此有

Loss Function可视化 transformer可视化,Loss Function可视化 transformer可视化_可视化_05,第5张

。所以,为了解释残差连接,给attention矩阵增加一个单位矩阵,并且对新矩阵的权重重新规范化。最后生成由残差连接更新的原始矩阵

Loss Function可视化 transformer可视化,Loss Function可视化 transformer可视化_Loss Function可视化_06,第6张


 

此外,分析单个attention head需要考虑输入在通过Transformer块中位置级的前向网络后,各个heads之间混合的信息。使用Attention Rollout和Attention Flow能够单独分析每个attention head中的信息,但是为了简便,一般在所有的attention heads上平均每一层的attention来进行分析。

 

Attention rollout 

给定一个L层的Transformer,希望计算从

Loss Function可视化 transformer可视化,Loss Function可视化 transformer可视化_Loss Function可视化_07,第7张

层所有位置到

Loss Function可视化 transformer可视化,Loss Function可视化 transformer可视化_结点_08,第8张

层(其中

Loss Function可视化 transformer可视化,Loss Function可视化 transformer可视化_权重_09,第9张

)所有位置的attention。在attention图中,从

Loss Function可视化 transformer可视化,Loss Function可视化 transformer可视化_Loss Function可视化_07,第7张

层位置k的结点v到

Loss Function可视化 transformer可视化,Loss Function可视化 transformer可视化_结点_08,第8张

层中位置m的结点u的路径是一系列连接这两个结点的边。如果将每条边的权重视为两个节点之间传输信息的比例,那么可以将路径中所有边的权重相乘来计算结点v中有多少信息通过该路径传播到了结点u。为了计算从

Loss Function可视化 transformer可视化,Loss Function可视化 transformer可视化_Loss Function可视化_07,第7张

层到

Loss Function可视化 transformer可视化,Loss Function可视化 transformer可视化_结点_08,第8张

层的attentions,在下面所有层中递归地乘以attention矩阵。

Loss Function可视化 transformer可视化,第14张j & \\ A(l_{i})& if i=j \end{matrix}\right." title="\widetilde{A}(l_{i}) = \left\{\begin{matrix} A(l_{i})\widetilde{A}(l_{i-1}) & if i>j & \\ A(l_{i})& if i=j \end{matrix}\right." style="width: 259px; visibility: visible;" data-type="block">

Loss Function可视化 transformer可视化,Loss Function可视化 transformer可视化_可视化_15,第15张

在计算公式中

Loss Function可视化 transformer可视化,Loss Function可视化 transformer可视化_结点_16,第16张

是attention rollout,A是原始的attention,相乘操作是矩阵乘法。在该公式下,设置j=0来计算输入的attention值。

 

 

 

 

 


https://www.xamrdz.com/backend/3h91942332.html

相关文章: