Transformer 时序模型

Transformer 模型

Transformer 来自于 Google Brain 团队 2017 年的文章 Attention is all you need。正如论文的题目所述,整个网络结构完全由注意力机制组成,由于没有使用 RNN 和 CNN,避免了无法并行计算和长距离依赖等问题,用更少的计算资源,取得了更好的结果,刷新了多项机器翻译任务的记录。

如图 7 从整体架构上看,transformer 仍属于编码器-解码器架构,通过编码器(Encoder)将输入序列转换成内部表示,再通过不同解码器(Decoder)实现不同的预测功能。

图 7.Transformer 架构

为什么 Attention is all you need?

作为 Transformer 论文的最大创新,Transformer 模型仅仅使用注意力机制不仅完成了以前需要 RNN 才能做到的工作,而且做的更快更好,下面我们就来看看 Transformer 是如何做到的。

自注意力

Transformer 模型的首要工作就是使用编码器生成序列编码,前面我们介绍了注意力机制具备聚合序列元素信息的能力,在 transformer 的编码器中就是使用这种能力来生成序列编码,由于在编码器中注意力机制的注意对象是输入序列自身,因此被称为自注意力(self attention)。时序问题(特别是自然语言处理问题)中的序列元素表示的含义通常不止该单个元素的的字面意义,而是与整个序列上下文有关系,因此在编码过程中需要考虑整个序列来决定其中每个元素的意义。自注意力机制中将每个元素都作为关注目标进行注意力计算,因此每个元素都对彼此在序列上下文中进行解释,很好的体现了这种通过全局确定局部的思想。 图 8 来自 Jay Alammar 的著名博文 The Illustrated Transformer,它可视化的展示了在机器翻译任务下自注意力机制在对输入元素"it"的解释过程中,"the"和"animal"都发挥了比较大的权重。

图 8.编码器自注意力可视化

注意力遮罩 Attention mask

由于注意力机制可以直接看到所有的元素,因此需要一种手段来防止注意力机制处理"不应该被看到的元素",这是指在模型训练阶段不能让解码器的自注意力机制看到训练数据中当前时间点之后的正确预测值,否则模型就会利用标准答案"作弊",如图 9 所示。

图 9.注意力遮罩

Scaled Dot-Product Attention (SDPA)

Transformer 对标准的注意力计算做了一个小小调整:加入特征缩放(feature scaling)。这样做主要是为了防止 softmax 运算将值较大的 key 过度放大,导致其他 key 的信息很难加入到计算结果中。 特征缩放体现在对 Q 和 K 计算点积 QKT 以后,增加了一步除以 √(d_k ) 运算。

图 10 是上式的图像化表示,其中 Scale 就是特征缩放的操作。

图 10.SDPA

位置编码 Positional encoding

与 RNN 和 CNN 不同,在注意力机制中没有先后顺序的概念(如第一个元素,第二个元素等),输入序列的所有元素都以没有特殊顺序或位置的方式输入网络,模型不知道元素的先后顺序。因此,需要将与位置相关的信号添加到每个元素中,以帮助模型理解序列中元素的排列顺序。最简单直接的位置编码方式是将每个元素的序号加入元素编码后再输入模型,这样做是否可行呢?考虑到序列的长度可以是任意长度,只讨论元素的绝对位置是不全面的(同一个词,在由 3 个词组成的句子中的第三个位置和 30 个词组成的句子中的第三个位置所表达的意思很可能是不一样的)。因此 Transformer 使用了基于周期函数(sin/cos 函数)的位置编码方法。Transformer 的位置编码 PE 可以表示为

其中 pos 表示位置,i 表示元素编码的维度,dmodel 表示模型的维度,这种位置编码有如下优点:

  • 利用 sin/cos 函数的周期性它能够进行任意长度序列的位置编码。
  • 由于 sin(i+x) 函数可以展开为 sin(i)和 cos(i) 的线性表达式,使得 PE(i+x) 的计算可以展开为 PEi 的线性表达式,因此计算相对位置的效率比较高。
  • 使用多个不同频率来保证不会由于周期性导致不同位置的编码相同。
  • sin/cos 函数的值总是在 -1 到 1 之间,这有利于神经网络的学习。

计算产生的位置编码是一个与元素具有相同维度的向量,使用相加的方式将位置信息叠加进元素中,如图 11 所示。在 Transformer 论文中没有解释为什么使用相加方式,直观感觉相加操作会造成对元素向量的污染,而串联(concatenate)就不会有这种问题。实验显示在高维中随机选择的向量几乎总是近似正交的,也就是说元素向量和位置编码向量是相互独立的。因此尽管进行了矢量相加,但两个向量仍可以通过一些变换而彼此独立地进行操作。也是正因为这种向量正交关系,串联并不会比相加表现得更好,但会大大增加学习参数方面的成本。

图 11.位置编码与元素编码进行相加操作

多头注意力 Multiple Headed Attention, MHA

Transformer 仅仅使用注意力机制处理输入生成序列编码,由于注意力机制本质上只是对输入进行加权平均运算,没有引入新参数也没有使用非线性运算,这导致复杂特征提取能力不足,为了解决这个问题论文提出了多头注意力的方法。和卷积神经网络通过使用多个卷积核来发掘不同特征的思路类似,多头注意力也是通过多次随机初始化过程来提取不同特征。 图 12 中通过三次随机初始化分别得到了三种特征:红色表示动作,绿色表做动作施加者,蓝色表示动作承受着,可以看到在对"踢"进行了三次自注意力运算,分别对应三种特征。在对于动作信息的自注意力运算中,"我"和"球"的权值(灰色细线表示)比"踢"的权值(红色粗线)要小很多;同样,对动作施加者的自注意力运算中,"我"(绿色粗线)则是主要贡献者。在将三次自注意力运算的结果相加后,得到的新的"踢"的编码中就包含了三种特征的信息。理论上随机初始化测次数越多就越有可能发现有效的特征,不过随之增长的是训练参数的增加,这意味着训练难度的提高,因此需要平衡,在 Transformer 模型中这个值是 8。

图 12.多头注意力的作用

具体实现来说多头注意力是对同一个元素进行多次注意力运算,每次注意力计算之前分别使用随机生成的参数 W^Q,W^K,W^V 通过矩阵相乘来初始化 Q,K,V,

其中:

  • 对于编码器多头自注意力 MHSA,Q,K,V 都是输入元素编码 xi
  • 对于解码器多头自注意力 MHSA,Q,K,V 都是已生成的输出元素编码 yi
  • 对于编码器-解码器多头注意力 MHA,Q 是输出元素编码 yi, K,V, 是 context vector 中的元素 ci

在分别完成 i 次注意力运算之后,再将运算结果进行合并。合并的方法是首先对 i 次结果进行串联(concatenate),由于 Transformer 模型要求输出和输入具有相同的维度来实现多个编码层串联,因此再通过对串联结果和 WO进行矩阵相乘得到和输入同样维度的结果。

Transformer 全貌

在介绍了 Transformer 的主要组成部分之后,我们再来完整看一下 Transformer 模型。整体上来看,Transformer 模型属于编码器-解码器架构,由于解码器需要根据序列编码和上一步的解码器输出来产生下一个输出,因此 Transformer 属于自回归模型(autoregressive model)。

图 13.Transformer 全貌

编码器

Transformer 的编码器负责处理分析提取输入序列的特征并生成序列编码。它由若干个编码层构成,所有编码层的结构完全一样,这些编码层相互串联在一起,编码器的输入首先进入第一个编码层,结算结果输入第二层,依次经过所有编码层后作为编码器的输出。 每个编码层由多头自注意力单元和按位前馈网络两部分组成。输入首先进入自注意力计算单元,再将计算结果输入按位前馈网络,这里的按位的含义是指每个位置的元素各自输入前馈网络里进行计算,前馈网络的结构为 2 个串联的全连接层,中间层维度较大(是元素编码维度的 4 倍),最后一层的维度和元素编码的维度相同。这个设计的目的和多头注意力的设计类似,还是由于注意力机制在特征合成能力的不足,需要借助全连接网络的非线性计算来增加复杂特征合成的能力。

解码器

解码器负责根据序列编码和上一步的解码器输出预测下一步的输出。它同样由多个结构相同的解码层串联而成,每个解码层由三部分组成,按照处理的先后顺序分别是解码自注意力单元 MHSA,编码器-解码器注意力单元 HMA 和按位前馈网络。作为解码器的核心,编码器-解码器 HMA 接收两个输入 Q,K,第一个输入Q 由解码器上一步输出经过带遮罩的解码器 HMSA 处理后得到,第二个输入 K 是编码器的输出:序列编码。编码器-解码器 HMA 的输出在经过按位前馈网络合成复杂特征。经过多个解码层处理后在通过全连接运算映射到目标词典空间,最后通过 softmax 选择可能性最大的元素作为输出。

图 14 展示了 Transformer 在进行英-中翻译任务中的主要工作流程:

  1. 输入元素进行位置编码,位置编码与输入元素编码按位相加
  2. 在编码层
    1. 首先进行输入元素自注意力(多头注意力)计算,
    2. 再将结果输入按位前馈网络
  3. 重复多次编码层结算,结束编码阶段,得到 context vector
  4. 开始解码阶段,首先对输出元素进行位置编码(第一个输出为开始标记 SOS), 输入元素与其位置编码按位相加
  5. 在解码层
    1. 首先进行输出元素(当前已输出)的多头自注意力计算
    2. 进行解码器-编码器多头注意力计算
    3. 对 5.2 结果按位前馈网络
  6. 重复多次解码层计算
  7. 通过全连接网络转化为目标词典维度向量,使用 softmax 确定输出元素(可能性最大)
  8. 将当前输出元素输入 4 开始下一个输出元素的计算,直到输出为结束标记符 EOS

图 14.使用 Transformer 进行机器翻译的流程

总结来说,注意力机制是 transformer 的核心,它具有计算效率高,可并行,容易训练等优势,但是同时也带了一些新问题:比如无序和特征合成能力下降。Transformer 针对这些新问题分别提出了解决方案,如使用位置编码生成位置信息,使用多头注意力和按位前馈网络增强特征合成能力。

Transformer 优化技巧

由于 Transformer 属于比较复杂的深度模型,因此要通过使用一些优化技巧才能进行训练。Transformer 中运用到的优化技术比较多,我们选择其中比较重要或者是有针对性的来进行简单介绍

1. 残差链接 Residual connection

网络越深,表达能力越强,所以在需要表达复杂特征(如 NLP,图像)的场景中使用的神经网络正在变得越来越深,但是深层网络带来了两个问题:

  1. 梯度弥散、爆炸,使得模型难以训练
  2. 网络退化(degradation),当网络深度到达一定程度后,性能就会随着深度的增加而下降。

图 15. 残差连接

残差链接用一个简单的办法巧妙的解决了这两个问题,就是将两个不相邻网络层直接连接(短接)。这样梯度 gradient 可以跨越中间层直接传递,避免经过中间层时梯度被多次缩放导致梯度弥散(爆炸)的问题;另一方面,实验证明当使用 RELU 作为激活函数时,残差连接也能有效防止网络退化。Transformer 中的每一个编码层(解码层)都使用了残差连接来分别短接多头注意力和按位前馈网络,这样做一来解决了梯度传递的问题,同时还能帮助位置信息顺利传递到高层去

2. 层归一化 Layer normalization

归一化是机器学习中常用的一种数据预处理方法,为了更有效的运行机器学习算法,需要将原始数据"白化"(Whitening),也就是在统计学中常常提到的使数据"独立,同分布"。 目前在深度学习中最常用的是批归一化(Batch   Normalization),它对不同训练数据的同一维度进行归一化,这种方法可以有效缓解深度模型训练中的梯度爆炸、弥散的问题。而在 transformer 采用了相对冷门的层归一化,主要原因是批归一化很难应用在训练数据长度不同的时序任务上,而这正是层归一化的优势所在,由于它是作用在单个训练数据的不同维度上,因此它能够在一条数据上进行归一化。

3. 标签平滑归一化 Label smoothing regularization

通常我们使用交叉熵来计算预测误差时使用独热(one-hot)编码表示真实值,梯度下降算法为了减小误差会尽量使预测结果接近独热编码,也就是说,网络会驱使自身往正确标签和错误标签差值大的方向学习,在训练数据不足以表征所有的样本特征的情况下,预测结果的置信度过高会导致网络过拟合。 标签平滑归一化通过"软化"传统的独热编码,使得训练时能够有效抑制过拟合现象。它的实现非常简单,通过一个超参数 ?∈(0,1) 将原来的 0,1 分布变成 ?,1-? 分布(对于二值分类问题),这样就缩短了真假值之间的距离,最终起到抑制过拟合的效果。

4. 学习率热身 Learning rate warm up

训练初期由于离目标较远,一般需要选择大的学习率,但如果训练数据集具有高度的差异性则使用过大的学习率则可能导致不稳定性。这是由于如果初始化后的数据恰好只包含一部分特征,则模型的初始训练可能会严重偏向于这些特征,这会增加模型学习其他特征的难度。 所以可以增加一个学习率热身阶段,在开始的时候先使用一个较小的学习率,然后当训练过程稳定的时候再把学习率调回去。在预热期间,学习率呈线性增加。如果目标学习率是 p ,预热期是 n,则第一批迭代将 p/n 用作学习率;第二个使用 2*p/n ,依此类推:迭代 i 使用 i*p/n,直到我们在迭代 n 次后达到学习率 p。

Transformer 的改进和发展

Transformer 取得巨大成功引起关注的同时,学术和产业界都在尝试在实现和理论层面对他进行改进和增强

应用(Transformer-XL)

展开阅读全文

本文系作者在时代Java发表,未经许可,不得转载。

如有侵权,请联系nowjava@qq.com删除。

编辑于

关注时代Java

关注时代Java