1743 字
9 分钟
矩阵分块思想实现高效多头注意力算法

最近阅读了Sebastian的《LLM From Scratch》,书中详细讲解了两种多头注意力机制:

  • 第一种方法是将多个注意力头独立计算,然后将它们各自输出的上下文向量拼接起来,如图所示

传统多头注意力机制的实现图

  • 第二种方法则采用了相反的思路:它通过矩阵分块的方式,在一个单头注意力模块中巧妙地实现了多头注意力机制。

首先需要明确三维张量的乘法运算原理: 假设A是一个2×3×22 \times 3 \times 2的三维张量,B是一个2×2×32 \times 2 \times 3的张量

pVK2P4f.png

实际上,这个计算过程可以理解为一种切片矩阵乘法的批量操作。我们将张量 AABB 沿 dim=0 方向切片,A的切片分别是A[0,:,:]A[1,:,:];B的切片分别是B[0,:,:]B[1,:,:],将两组切片的乘法运算结果在dim=0上扩展就是最终计算结果。即

output=torch.stack([A[0,:,:]@B[0,:,:],A[1,:,:]@B[1,:,:]],dim=0)

pVK2vGT


介绍了三维张量的乘法运算,接下来我们再来看看是如何用单头注意力实现多头注意力的。

pVK0hCj.md.png 为了更容易地理解这个过程,我们先做出如下假设:

  • 注意力头个数为num_heads=2
  • 单个批次的输入矩阵input的维度为context_length ×\times emb_size
  • context_length
  • 嵌入层输出维度emb_size
  • batch_size=1(本文不考虑batch_size这个维度
  • 每个注意力头的KiK_iQiQ_iViV_i向量维度为d_i_out
  • 整个(多头)KKQQVV向量维度为d_out
  • KKQQVV矩阵的初始形状为:context_length ×\times d_out

这个方法本质上是在一个单头注意力模块中运算,因此WqW_q ,WkW_k ,WvW_v 分别只有一个,此时虽然只有一个 KKQQVV 矩阵,但我们可以将它们视为多个注意力头的 KiK_iQiQ_iViV_i 拼接而成。最终的计算目标是为了得到上下文向量(把整个上下文向量看作是多个注意力头得到的上下问向量的拼接),如图所示(2个注意力头)

拼接

其中cvi=AttentionWeighti@Vicv_i=AttentionWeight_i @ V_i ,接下来应该计算AttentionWeightAttentionWeight

AttentionWeights=softmax(QKTdk)AttentionWeights=\mathbf{softmax}(\frac{QK^T}{\sqrt{d_k}})

我们知道在单头注意力中AttentionScore=Q@KTAttentionScore=Q@K^T,那么现在的问题是如何将QQKK视为多头注意力的QQKK的组合来计算得到多头注意力下的AttentionScoreAttentionScore,现在就可以利用上面所说的高维矩阵的乘法运算(切片乘法的批量运算)来完成这一目标了。

我们希望通过一系列变换,使得 Q[0,:,:]=Q1Q[0,:,:] = Q_1Q[1,:,:]=Q2Q[1,:,:] = Q_2KKVV 同理,在这种情况,AttentionScore=Q@K.transpose(1,2)AttentionScore=Q@K.transpose(1,2)(其中AttentionScoreAttentionScore的形状为num_heads ×\times context_length ×\times context_lengthQQ KK VV的形状为num_heads ×\times context_length ×\times d_i_out),从而可以完成每个注意力头的 AttentionScoreAttentionScore 计算。

书中介绍的变换是这样的:

Q=Q.view((context_length,num_heads,d_i_out))
Q=Q.transpose(0,1)
K=K.view((context_length,num_heads,d_i_out))
K=K.transpose(0,1)
V=V.view((context_length,num_heads,d_i_out))
V=V.transpose(0,1)

完成上述变换后:

  • QQ的形状就变为了num_heads ×\times context_length ×\times d_i_out,并且Q[0,:,:]=Q1Q[0,:,:]=Q_1 Q[1,:,:]=Q2Q[1,:,:]=Q_2

  • AttentionScore=Q@K.transpose(1,2)AttentionScore=Q@K.transpose(1,2)

  • AttentionWeightAttentionWeight的形状也是num_heads ×\times context_length ×\times context_length

然后就可以计算得到上下文向量了

contextVec=AttentionWeight@VcontextVec=AttentionWeight @ V

这里ContextVec的形状为num_heads ×\times context_length×\times d_i_out

然后对contextVec做如下变换,可以将其转换为如图所示的形状(即由多个注意力头输出的上下文向量在dim=1上拼接而成的二维张量,形状为context_length*d_out)

ContextVec=ContextVec.transpose(0,1)
#now contextVec.shape(context_length,num_heads,d_i_out)
contextVec=contextVec.view((context_length,d_out))

contextVec


解释一下这个变换的原理:

首先要说明一下张量数据的物理内存排列和逻辑排列

  • 张量数据的物理内存排列是根据最初定义的形式按行优先排列的。例如对于一个张量a=[[1,2,3],[4,5,6],[7,8,9]],其在内存中的排列是1,2,3,4,5,6,7,8,9

  • 张量的逻辑排列是在张量的物理内存排列的基础上使用stride()策略的视图。 print(a)输出的是a的逻辑视图。假设a.stride()=(m,n),则a[i,j]这个元素在物理内存的索引为i×m+j×ni \times m +j \times n,使用sride()策略,pytorch可以在张量的逻辑视图发生改变的情况下不需要频繁IO,提高了性能。

  1. view()操作不改变数据的逻辑排列:pytorch的view()操作不会改变数据的逻辑排列(数据在内存中一般按照行优先存储)的相对位置,只会改变解释数据的方式,对于一个形状为3*4的张量A:
#现在是将内存中这一串数据按照3*4的二维张量的形状来解释
A=torch.tensor(
[[1, 2 , 3 , 4],
[5, 6 , 7 , 8],
[9, 10, 11,12]]
)
#现在按照3*2*2的三维张量的形状来解释A
A=A.view((3,2,2))

print(A)的输出如下:

A=tensor([
[[ 1, 2],[ 3, 4]],
[[ 5, 6],[ 7, 8]],
[[ 9, 10],[11, 12]]
])

内存中数据是按行优先来排列,而view()操作前后数据的逻辑排列顺序是相同的。

  1. A=A.tranpose(0,1)能改变数据的逻辑排列

如果将内部长度为 2 的向量视为一个整体,张量 AA 可以重新组织为:

A=tensor(
[[a_11,a_12],
[a_21,a_22],
[a_31,a_32]]
)
#A.transpose(0,1)相当于对此二维张量做转置变换
#那么A_T是这样的
A_T=tensor([
[a_11 ,a_21 ,a_31],
[a_12 ,a_22 ,a_32]
])

然后还原内部长度为2的向量,A变为了

A_T=tensor([
[[1,2] ,[5,6] ,[9,10]],
[[3,4] ,[7,8] ,[11,12]]
])
#格式化一下就看得清楚了
#经过格式化之后A_T如下:
A_T=tensor([
[[ 1, 2],
[ 5, 6],
[ 9, 10]],
[[ 3, 4],
[ 7, 8],
[11, 12]]])

可以看到,A_T 中的数据逻辑排列与转置前相比已经发生了变化。

此外可以看到现在A_T的形状变为了2×\times3×\times2,而且

torch.cat([A[0,:,:],A[1,:,:]],dim=-1)=A

综上所述,这种方法通过张量的结构变换与分块运算,在一个注意力模块中实现了等效的多头注意力计算。


可以发现两种不同的多头注意力实现并没有减少参数的数量,那为什么第二种实现高效呢?

传统算法与高效算法对比:

  • K,Q,VK,Q,V相关计算和访存次数:传统算法每个注意力头都需要进行K,Q,VK,Q,V的访存和计算 ,总计算次数=3×\times 单个注意力头的计算次数;高效算法只有单个注意力头的计算次数,并且进行K,Q,VK,Q,V运算也不用多次访存,总计算次数=单个注意力头的计算次数
  • 权重矩阵的访存次数:传统算法中每个注意力头都要独立访存WK,WQ,WVW_K,W_Q,W_V矩阵,并且每个注意力的计算过程中都需要开辟内存来存储计算结果,最后还有一次拼接操作;高效算法只需要访存一次WK,WQ,WVW_K,W_Q,W_V矩阵,不需要拼接结果。

高效算法不仅减少了权重的访存频率,也避免了重复计算与拼接操作带来的开销,极大地提升了运行效率。类似的矩阵分块思想也可以应用于优化其他具有重复结构的张量运算过程。

记着先,免得忘记了

矩阵分块思想实现高效多头注意力算法
https://www.skywalkerch.top/posts/efficient-multi-head-attention-algorithm-using-matrix-block-decomposition/
作者
skywalkerch
发布于
2025-07-15
许可协议
CC BY-NC-SA 4.0