最近阅读了Sebastian的《LLM From Scratch》,书中详细讲解了两种多头注意力机制:
- 第一种方法是将多个注意力头独立计算,然后将它们各自输出的上下文向量拼接起来,如图所示
- 第二种方法则采用了相反的思路:它通过矩阵分块的方式,在一个单头注意力模块中巧妙地实现了多头注意力机制。
首先需要明确三维张量的乘法运算原理: 假设A是一个的三维张量,B是一个的张量
实际上,这个计算过程可以理解为一种切片矩阵乘法的批量操作。我们将张量 和 沿 dim=0 方向切片,A的切片分别是A[0,:,:]
和A[1,:,:]
;B的切片分别是B[0,:,:]
和B[1,:,:]
,将两组切片的乘法运算结果在dim=0
上扩展就是最终计算结果。即
output=torch.stack([A[0,:,:]@B[0,:,:],A[1,:,:]@B[1,:,:]],dim=0)
介绍了三维张量的乘法运算,接下来我们再来看看是如何用单头注意力实现多头注意力的。
为了更容易地理解这个过程,我们先做出如下假设:
- 注意力头个数为num_heads=2
- 单个批次的输入矩阵input的维度为context_length emb_size
- context_length
- 嵌入层输出维度emb_size
- batch_size=1(本文不考虑batch_size这个维度)
- 每个注意力头的,,向量维度为d_i_out
- 整个(多头),,向量维度为d_out
- ,,矩阵的初始形状为:context_length d_out
这个方法本质上是在一个单头注意力模块中运算,因此 , , 分别只有一个,此时虽然只有一个 、、 矩阵,但我们可以将它们视为多个注意力头的 、、 拼接而成。最终的计算目标是为了得到上下文向量(把整个上下文向量看作是多个注意力头得到的上下问向量的拼接),如图所示(2个注意力头)
其中 ,接下来应该计算,
我们知道在单头注意力中,那么现在的问题是如何将和视为多头注意力的和的组合来计算得到多头注意力下的,现在就可以利用上面所说的高维矩阵的乘法运算(切片乘法的批量运算)来完成这一目标了。
我们希望通过一系列变换,使得 、, 和 同理,在这种情况,(其中的形状为num_heads context_length context_length, 的形状为num_heads context_length d_i_out),从而可以完成每个注意力头的 计算。
书中介绍的变换是这样的:
Q=Q.view((context_length,num_heads,d_i_out))Q=Q.transpose(0,1)
K=K.view((context_length,num_heads,d_i_out))K=K.transpose(0,1)
V=V.view((context_length,num_heads,d_i_out))V=V.transpose(0,1)
完成上述变换后:
-
的形状就变为了num_heads context_length d_i_out,并且
-
-
的形状也是num_heads context_length context_length
然后就可以计算得到上下文向量了
这里ContextVec的形状为num_heads context_length d_i_out
然后对contextVec做如下变换,可以将其转换为如图所示的形状(即由多个注意力头输出的上下文向量在dim=1上拼接而成的二维张量,形状为context_length*d_out)
ContextVec=ContextVec.transpose(0,1)#now contextVec.shape(context_length,num_heads,d_i_out)contextVec=contextVec.view((context_length,d_out))
解释一下这个变换的原理:
首先要说明一下张量数据的物理内存排列和逻辑排列
张量数据的物理内存排列是根据最初定义的形式按行优先排列的。例如对于一个张量
a=[[1,2,3],[4,5,6],[7,8,9]]
,其在内存中的排列是1,2,3,4,5,6,7,8,9
。张量的逻辑排列是在张量的物理内存排列的基础上使用stride()策略的视图。 print(a)输出的是a的逻辑视图。假设a.stride()=(m,n),则a[i,j]这个元素在物理内存的索引为,使用sride()策略,pytorch可以在张量的逻辑视图发生改变的情况下不需要频繁IO,提高了性能。
- view()操作不改变数据的逻辑排列:pytorch的view()操作不会改变数据的逻辑排列(数据在内存中一般按照行优先存储)的相对位置,只会改变解释数据的方式,对于一个形状为3*4的张量A:
#现在是将内存中这一串数据按照3*4的二维张量的形状来解释A=torch.tensor( [[1, 2 , 3 , 4], [5, 6 , 7 , 8], [9, 10, 11,12]])
#现在按照3*2*2的三维张量的形状来解释AA=A.view((3,2,2))
print(A)的输出如下:
A=tensor([[[ 1, 2],[ 3, 4]],[[ 5, 6],[ 7, 8]],[[ 9, 10],[11, 12]]])
内存中数据是按行优先来排列,而view()操作前后数据的逻辑排列顺序是相同的。
A=A.tranpose(0,1)
能改变数据的逻辑排列
如果将内部长度为 2 的向量视为一个整体,张量 可以重新组织为:
A=tensor([[a_11,a_12],[a_21,a_22],[a_31,a_32]])#A.transpose(0,1)相当于对此二维张量做转置变换#那么A_T是这样的A_T=tensor([[a_11 ,a_21 ,a_31],[a_12 ,a_22 ,a_32]])
然后还原内部长度为2的向量,A变为了
A_T=tensor([[[1,2] ,[5,6] ,[9,10]],[[3,4] ,[7,8] ,[11,12]]])#格式化一下就看得清楚了#经过格式化之后A_T如下:A_T=tensor([ [[ 1, 2], [ 5, 6], [ 9, 10]],
[[ 3, 4], [ 7, 8], [11, 12]]])
可以看到,A_T 中的数据逻辑排列与转置前相比已经发生了变化。
此外可以看到现在A_T的形状变为了232,而且
torch.cat([A[0,:,:],A[1,:,:]],dim=-1)=A
综上所述,这种方法通过张量的结构变换与分块运算,在一个注意力模块中实现了等效的多头注意力计算。
可以发现两种不同的多头注意力实现并没有减少参数的数量,那为什么第二种实现高效呢?
传统算法与高效算法对比:
- 相关计算和访存次数:传统算法每个注意力头都需要进行的访存和计算 ,总计算次数=3 单个注意力头的计算次数;高效算法只有单个注意力头的计算次数,并且进行运算也不用多次访存,总计算次数=单个注意力头的计算次数
- 权重矩阵的访存次数:传统算法中每个注意力头都要独立访存矩阵,并且每个注意力的计算过程中都需要开辟内存来存储计算结果,最后还有一次拼接操作;高效算法只需要访存一次矩阵,不需要拼接结果。
高效算法不仅减少了权重的访存频率,也避免了重复计算与拼接操作带来的开销,极大地提升了运行效率。类似的矩阵分块思想也可以应用于优化其他具有重复结构的张量运算过程。
记着先,免得忘记了