首页 > 代码库 > TensorFlow文本摘要生成 - 基于注意力的序列到序列模型

TensorFlow文本摘要生成 - 基于注意力的序列到序列模型

1 相关背景

维基百科对自动摘要生成的定义是, “使用计算机程序对一段文本进行处理, 生成一段长度被压缩的摘要, 并且这个摘要能保留原始文本的大部分重要信息”. 摘要生成算法主要分为抽取型(Extraction-based)和概括型(Abstraction-based)两类. 传统的摘要生成系统大部分都是抽取型的, 这类方法从给定的文章中, 抽取关键的句子或者短语, 并重新拼接成一小段摘要, 而不对原本的内容做创造性的修改. 这类抽取型算法工程上已经有很多开源的解决办法了, 例如Github上的项目sumy, pytextrank, textteaser等. 本文重点讲概括型摘要生成系统的算法思想和tensorflow实战, 算法思想源于A Neural Attention Model for Abstractive Sentence Summarization这篇论文. 本文希望帮助读者详细的解析算法的原理, 再结合github上相关的开源项目textsum讲解工程上的实际应用.本文由PPmoney大数据算法团队撰写,PPmoney是国内领先的互联网金融公司,旗下PPmoney理财总交易额超过700亿元。此外,若对TensorFlow的使用技巧和方法感兴趣,欢迎阅读本团队负责人所著的《TensorFlow实战》。

2 算法原理

下面对A Neural Attention Model for Abstractive Sentence Summarization这篇文章, 的算法原理进行讲解. 我们将这个模型简称为NAM. 主要分为模型训练(train)和生成摘要(decode)两部分讲解.

2.1 模型训练(train)

NAM这个模型是纯数据驱动, 我们喂给它的训练集数据是由一系列{正文: 摘要}对组成. 假设正文是x=[x1,...,xM]<script type="math/tex" id="MathJax-Element-3">\textbf{x} = [\textbf{x}_1,...,\textbf{x}_M]</script>, M<script type="math/tex" id="MathJax-Element-4">M</script>是正文词符的数量, 对应的摘要为y=[y1,...,yN]<script type="math/tex" id="MathJax-Element-5">\textbf{y} = [\textbf{y}_1,...,\textbf{y}_N]</script>, N<script type="math/tex" id="MathJax-Element-6">N</script>是摘要单词的数量.
对于给定的数据, 我们希望给定x<script type="math/tex" id="MathJax-Element-7">\textbf{x}</script>生成摘要为y<script type="math/tex" id="MathJax-Element-8">\textbf{y}</script>的概率最大, 即maxθlogp(y|x;θ)<script type="math/tex" id="MathJax-Element-9">\max_\theta{\log p(\textbf{y}|\textbf{x};\theta)}</script>, θ<script type="math/tex" id="MathJax-Element-10">\theta</script>是模型的参数. 但这个很难求解, 实际中我们用序列化的方式实例化这个目标, 原来的目标函数变为:

maxθi=0N?1logp(yi+1|x,yc;θ)
<script type="math/tex; mode=display" id="MathJax-Element-11">\max_\theta {\sum_{i=0}^{N-1}\log p(\textbf{y}_{i+1}|\textbf{x},\textbf{y}_c;\theta)}</script>
这里 yi+1<script type="math/tex" id="MathJax-Element-12">\textbf{y}_{i+1}</script>是要预测的下一个词, yc?y[i?C+1,...,i]<script type="math/tex" id="MathJax-Element-13">\textbf{y}_c\triangleq\textbf{y}_{[i-C+1,...,i]}</script>是已知的序列, C<script type="math/tex" id="MathJax-Element-14">C</script>是已知序列窗口的长度. 后面会提到, 这个窗口的位置也是注意力关注的位置, 在后面的训练过程中会根据学习到的权重调整不同位置注意力的概率大小. 这个窗口是随着i<script type="math/tex" id="MathJax-Element-15">i</script>的迭代来滑动的.
参数说明:
y<script type="math/tex" id="MathJax-Element-16">\textbf{y}</script>: 参考摘要所有单词向量组成的序列
x<script type="math/tex" id="MathJax-Element-17">\textbf{x}</script>: 正文的所以单词向量组成的序列
i<script type="math/tex" id="MathJax-Element-18">i</script>: 当前评估函数所对应的位置
yc<script type="math/tex" id="MathJax-Element-19">\textbf{y}_c</script>: 当前训练的窗口对应的局部摘要序列
yi+1<script type="math/tex" id="MathJax-Element-20">\textbf{y}_{i+1}</script>: 模型要预测的下一个单词

下面我们举一个例子来说明训练的过程:
技术分享
我们希望根据, 当前局部摘要序列yc<script type="math/tex" id="MathJax-Element-21">\textbf{y}_c</script>和全部的正文信息x<script type="math/tex" id="MathJax-Element-22">\textbf{x}</script>, 来预测下一个单词yi+1<script type="math/tex" id="MathJax-Element-23">\textbf{y}_{i+1}</script>. 我们希望模型预测下一个单词为yi+1<script type="math/tex" id="MathJax-Element-24">\textbf{y}_{i+1}</script>的概率最大, 并且希望所有单词都尽可能的预测准确, 在公式上表现为N?1i=0logp(yi+1|x,yc;θ)<script type="math/tex" id="MathJax-Element-25">\sum_{i=0}^{N-1}\log p(\textbf{y}_{i+1}|\textbf{x},\textbf{y}_c;\theta)</script>最大. 窗口C<script type="math/tex" id="MathJax-Element-26">C</script>会从摘要的起始位置滑动到终止位置, 当i<C<script type="math/tex" id="MathJax-Element-27">iyc<script type="math/tex" id="MathJax-Element-28">\textbf{y}_c</script>超出摘要的部分用起始符号<s>来补全.
我们感兴趣的分布p(yi+1|x,yc;θ)<script type="math/tex" id="MathJax-Element-29">p(\textbf{y}_{i+1}|\textbf{x},\textbf{y}_c;\theta)</script>是基于输入语句x<script type="math/tex" id="MathJax-Element-30">x</script>的条件语言模型. 这里我们直接将原始的分布, 参数化为一个神经网络. 这个神经网络既包括了一个神经概率语言模型(neural probabilistic language model), 也包括了一个编码器(这个编码器就是一个条件摘要模型).
通过包含编码器并且联合训练这两个组块, 我们根据当前yc<script type="math/tex" id="MathJax-Element-31">\textbf{y}_c</script>对x<script type="math/tex" id="MathJax-Element-32">\textbf{x}</script>的不同内容投入不同的关注度, 进而的到更好的结果. 模型结构如下图所示:
技术分享

  • 模型整体的网络结构图(具有一个额外的编码器单元):
    右侧分支: 仅根据当前的序列yc<script type="math/tex" id="MathJax-Element-33">\textbf{y}_c</script>预测下一个单词是yi+1<script type="math/tex" id="MathJax-Element-34">\textbf{y}_{i+1}</script>的概率, E<script type="math/tex" id="MathJax-Element-35">\textbf{E}</script>是词嵌入, yc<script type="math/tex" id="MathJax-Element-36">\tilde{\textbf{y}}‘_c</script> -> h<script type="math/tex" id="MathJax-Element-37">\textbf{h}</script>包括加权和激活函数的操作.
    左侧分支: 使用yc<script type="math/tex" id="MathJax-Element-38">\textbf{y}_c</script>和x<script type="math/tex" id="MathJax-Element-39">\textbf{x}</script>生成隐层的下一个输出, yc<script type="math/tex" id="MathJax-Element-40">\textbf{y}_c</script>会对encoder产生影响, 让encoder更多的关注x<script type="math/tex" id="MathJax-Element-41">\textbf{x}</script>中与yc<script type="math/tex" id="MathJax-Element-42">\textbf{y}_c</script>有关的内容.
    联合输出: 最终结合右侧的神经语言模型和左侧attention-based编码器的输出, 求下一个词是yi+1<script type="math/tex" id="MathJax-Element-43">\textbf{y}_{i+1}</script>的概率.

  • 基于注意力模型的编码器enc31的网络结构图:
    左侧分支: F<script type="math/tex" id="MathJax-Element-44">\textbf{F}</script>是词嵌入矩阵, x<script type="math/tex" id="MathJax-Element-45">\tilde{\textbf{x}}</script> -> xˉ<script type="math/tex" id="MathJax-Element-46">\bar{\textbf{x}}</script>是做了一下平滑处理.
    右侧分支: G<script type="math/tex" id="MathJax-Element-47">\textbf{G}</script>是词嵌入矩阵, 根据当前的yc<script type="math/tex" id="MathJax-Element-48">\textbf{y}‘_c</script>, 对x<script type="math/tex" id="MathJax-Element-49">\tilde{\textbf{x}}</script>的不同位置投入不同的注意力, 并形成一个加权向量.
    联合输出: 此时p<script type="math/tex" id="MathJax-Element-50">\textbf{p}</script>已经携带了注意力的信息, 用p<script type="math/tex" id="MathJax-Element-51">\textbf{p}</script>对平滑后的xˉ<script type="math/tex" id="MathJax-Element-52">\bar{\textbf{x}}</script>再做加权, 得到encoder的输出.
    下面两幅图分别是对整体结构和编码器结构的展开:
    技术分享
    技术分享

感兴趣的同学可以结合原文中的公式理解:
上图(a)中对应的公式:

p(yi+1|x,yc;θ)exp(Vh+Wenc(x,yc)),yc~=[Eyi?C+1,...,Eyi],h=tanh(Uyc~)
<script type="math/tex; mode=display" id="MathJax-Element-53"> p(\textbf{y}_{i+1}|\textbf{x},\textbf{y}_c;\theta) \propto \exp(\textbf{V}\textbf{h}+\textbf{W}enc(\textbf{x},\textbf{y}_c)),\\tilde{\textbf{y}_c} = [\textbf{E}\textbf{y}_{i-C+1},...,\textbf{E}\textbf{y}_{i}],\\textbf{h} = \tanh(\textbf{U}\tilde{\textbf{y}_c})</script>
参数是:
θ=(E,U,V,W),<script type="math/tex" id="MathJax-Element-54">\theta = (\textbf{E},\textbf{U},\textbf{V},\textbf{W}),</script>
E?D×V<script type="math/tex" id="MathJax-Element-55">\textbf{E}\in \mathbb{R}^{D\times V}</script>, 是一个词嵌入矩阵;
U?(CD)×H,V?V×H,W?V×H<script type="math/tex" id="MathJax-Element-56">\textbf{U}\in \mathbb{R}^{(CD)\times H}, \textbf{V}\in \mathbb{R}^{V\times H}, \textbf{W}\in \mathbb{R}^{V\times H}</script>, 是权重矩阵.
上图(b)中对应的公式:
enc3(x,yc)=pTxˉ,pexp(xPyc~),x=[Fx1,...,FxM],yc~=[Gyi?C+1,...,Gyi],?i,xˉi=q=i?Qi+Qxi/Q
<script type="math/tex; mode=display" id="MathJax-Element-57"> enc3(\textbf{x},\textbf{y}_c) = \textbf{p}^T\bar{ \textbf{x}},\\textbf{p}\propto\exp(\tilde{\textbf{x}}\textbf{P}\tilde{\textbf{y}_c}‘),\\tilde{\textbf{x}}=[\textbf{F}\textbf{x}_1,...,\textbf{F}\textbf{x}_M],\\tilde{\textbf{y}_c}‘=[\textbf{G}\textbf{y}_{i-C+1},...,\textbf{G}\textbf{y}_{i}],\\forall{i}, \bar{ \textbf{x}}_i = \sum_{q=i-Q}^{i+Q}\tilde{ \textbf{x}}_i/Q </script>
这里G?D×V<script type="math/tex" id="MathJax-Element-58">\textbf{G}\in\mathbb{R}^{D\times V}</script>是一个内容的嵌入, P?H×(CD)<script type="math/tex" id="MathJax-Element-59">\textbf{P}\in\mathbb{R}^{H\times (CD)}</script>是一个新的权重矩阵参数, Q<script type="math/tex" id="MathJax-Element-60">Q</script>是一个平滑窗口.
Mini-batch训练
这个模型是纯数据驱动的, 只要给它{正文: 摘要}训练集就能完成训练. 一旦我们已经定义了局部条件模型p(yi+1|x,yc;θ)<script type="math/tex" id="MathJax-Element-61">p(\textbf{y}_{i+1}|\textbf{x},\textbf{y}_c;\theta)</script>, 我们就能估计参数来最小化摘要集合的负对数似然函数. 假设训练集由J<script type="math/tex" id="MathJax-Element-62">J</script>个输入-摘要对组成(x(1),y(1)),...,(x(J),y(J))<script type="math/tex" id="MathJax-Element-63">(\textbf{x}^{(1)},\textbf{y}^{(1)}),...,(\textbf{x}^{(J)},\textbf{y}^{(J)})</script>. 负对数似然函数作用到摘要的每一个词, 即
NLL(θ)=?j=1Jlogp(y(j)|x(j);θ)=?j=1Ji=1N?1logp(y(j)i+1|x(j),yc;θ)
<script type="math/tex; mode=display" id="MathJax-Element-64">\textrm{NLL}(\theta)=-\sum_{j=1}^J\log p(\textbf{y}^{(j)}| \textbf{x}^{(j)}; \theta)=-\sum_{j=1}^J\sum_{i=1}^{N-1}\log p(\textbf{y}_{i+1}^{(j)}|\textbf{x}^{(j)},\textbf{y}_c;\theta)</script>
我们通过使用mini-batch和随机梯度下降最小化NLL.

2.2 Beam Search生成摘要(decode)

我们现在回到生成摘要的问题. 回顾前面, 我们的目标是找到:

y?=argmaxyi=0N?1logp(yi+1|x,yc;θ)
<script type="math/tex; mode=display" id="MathJax-Element-65">\textbf{y}^* = \arg\max_{\textbf{y}\in \mathcal{Y}}\sum_{i=0}^{N-1}\log p(\textbf{y}_{i+1}|\textbf{x},\textbf{y}_c;\theta)</script>
<script type="math/tex" id="MathJax-Element-66">\mathcal{Y}</script>是长度为N<script type="math/tex" id="MathJax-Element-67">N</script>的序列y<script type="math/tex" id="MathJax-Element-68">\textbf{y}</script>组成的集合, 如果字典中的单词数量是V<script type="math/tex" id="MathJax-Element-69">V</script>的话, 我们要生成的这个摘要就有VN<script type="math/tex" id="MathJax-Element-70">V^N</script>种可能性. 因为我们这里已经做了处理, 只根据前面的C<script type="math/tex" id="MathJax-Element-71">C</script>个已经预测出的单词yc<script type="math/tex" id="MathJax-Element-72">\textbf{y}_c</script>来预测下一个词yi+1<script type="math/tex" id="MathJax-Element-73">\textbf{y}_{i+1}</script>. 这样算法复杂度变成了O(NVC)<script type="math/tex" id="MathJax-Element-74">O(NV^C)</script>. 但是即使是这样, 这个算法也太复杂了.
使用维特比译码需要O(NVC)<script type="math/tex" id="MathJax-Element-75">O(NV^C)</script>.复杂度获得精确的解. 然而在实际中V<script type="math/tex" id="MathJax-Element-76">V</script>太大使得问题难解. 一个替代方法是使用贪婪解来近似获得argmax, 只保证每次前进的一小步是概率最大的.
在精确解和贪婪解方法之间取一个折中, 就是beam-search束搜索解码器(Algorithm1), 它在保持全量字典V<script type="math/tex" id="MathJax-Element-77">V</script>的同时, 在输出摘要的每一个位置上将自己限制在K<script type="math/tex" id="MathJax-Element-78">K</script>个潜在的假设内. 这种beam-search方法在神经机器翻译模型NMT也很常用. Beam search算法展示如下:
技术分享
参数说明:
N<script type="math/tex" id="MathJax-Element-79">N</script>: 摘要的长度
K<script type="math/tex" id="MathJax-Element-80">K</script>: beam的尺寸
V<script type="math/tex" id="MathJax-Element-81">V</script>: 字典里所有单词的数量
C<script type="math/tex" id="MathJax-Element-82">C</script>: 关注的词序列的长度

Beam search案例

下面举一个简单的例子来说明beam search算法的运行过程. 在这个例子里, 摘要长度N=4<script type="math/tex" id="MathJax-Element-83">N=4</script>, beam的大小K=6<script type="math/tex" id="MathJax-Element-84">K=6</script>, 注意力窗口大小C=2<script type="math/tex" id="MathJax-Element-85">C=2</script>, 模型最理想的结果是‘i am a chinese’. Beamsearch的每一次迭代都从字典V<script type="math/tex" id="MathJax-Element-86">V</script>里找K<script type="math/tex" id="MathJax-Element-87">K</script>个最大的可能.
技术分享
Step1: 预测前C<script type="math/tex" id="MathJax-Element-88">C</script>个词的时候窗口溢出的部分需要进行padding操作, 预测第1个词的时候我们选出K<script type="math/tex" id="MathJax-Element-89">K</script>个词符.
技术分享
Step2: 预测第2个词的时候, 我们选出新的K个词符, 对应K条备选路径. 前一阶段概率低的路径和词符, 被抛弃掉.
技术分享
Step3: 重复前面的过程.
技术分享
Step4: 每次beam search不一定能选出不同的K个词, 但是每次beam search都找到最优的前K个路径, 路径可以有重叠.
技术分享
Step5: 迭代N次, 最终选出可能性最大的一条词序列路径
技术分享
下面是对Beam Search算法的详细分析, 对原文的Algorithm 1逐条进行解释.

Beam Search算法分析

  1. π[0]<script type="math/tex" id="MathJax-Element-90">\pi[0]</script>是可以用规定好的起始符号<s>来初始化. 在训练和生成摘要时, 窗口Q<script type="math/tex" id="MathJax-Element-91">Q</script>和C<script type="math/tex" id="MathJax-Element-92">C</script>沿着文本滑动如果超出范围, 用起始符号<s>做padding.
  2. 如果模型是abstraction-based, 输出y<script type="math/tex" id="MathJax-Element-93">\textbf{y}</script>的备选集合是整个字典, 如果希望摘要的单词全部从原文中抽取, 那么词典由输入正文x<script type="math/tex" id="MathJax-Element-94">\textbf{x}</script>的所有单词构成.
  3. 我们会设定一个最大输出长度N<script type="math/tex" id="MathJax-Element-95">N</script>, 算法会进行N<script type="math/tex" id="MathJax-Element-96">N</script>轮迭代.
    1. 现已有K<script type="math/tex" id="MathJax-Element-97">K</script>个假设, 每一个假设都对应一条路径; 对每一个假设, 我们从字典S<script type="math/tex" id="MathJax-Element-98">S</script>(有V<script type="math/tex" id="MathJax-Element-99">V</script>个单词)中选出K<script type="math/tex" id="MathJax-Element-100">K</script>个单词作为备选.
    2. 在字典中寻找, 搜索其他单词, 如果计算的到的state值比当前集合中的任意一个大, 就把它保留下来.
    3. 当每一个假设都遍历完整个字典S<script type="math/tex" id="MathJax-Element-101">S</script>, 就会产生K×K<script type="math/tex" id="MathJax-Element-102">K\times K</script>条路径, 我们在这些路径中选择概率最大的K<script type="math/tex" id="MathJax-Element-103">K</script>个路径作为下一次迭代的基础.(每一条路径都保留了之前i?1<script type="math/tex" id="MathJax-Element-104">i-1</script>个节点对应的单词)
  4. N<script type="math/tex" id="MathJax-Element-105">N</script>次迭代进行完后, 我们只剩下了K<script type="math/tex" id="MathJax-Element-106">K</script>条路径, 最后在从这其中选出1条概率最大的即可.
  5. 路径所经历的所有节点即为摘要的单词. 如果这中间遇到了停止符<e>, 摘要就是从<s><e>, 如果没有<e>出现, 摘要的最大长度就是N<script type="math/tex" id="MathJax-Element-107">N</script>.

Beam Search的运算复杂度从O(NVC)<script type="math/tex" id="MathJax-Element-108">O(NV^C)</script>变成了O(KNV)<script type="math/tex" id="MathJax-Element-109">O(KNV)</script>, 因为V>>N<script type="math/tex" id="MathJax-Element-110">V>>N</script>和K<script type="math/tex" id="MathJax-Element-111">K</script>, 加速效果非常显著. 束搜索依据已经计算好的路径以及当前的V<script type="math/tex" id="MathJax-Element-112">V</script>个备选值, 计算出最优的K<script type="math/tex" id="MathJax-Element-113">K</script>的值. 最新的K<script type="math/tex" id="MathJax-Element-114">K</script>个最优值都保留着相应路径上之前的所有的节点.

3 TensorFlow程序实战

NAM模型的程序最早是由facebook开源的torch版本的程序. 最近谷歌开源了TensorFlow版本的摘要生成程序textsum, Github上的项目. textsum的核心模型就是基于注意力的seq2seq(sequence-to-sequence)模型, textsum使用了LSTM和深度双向RNN.
Github上的textsum首页给出了此项目在Bazel环境下的运行方式. 如果你不想通过Bazel运行, 你可以直接在seq2seq_attention.py中设定运行参数. 设定完参数后, 直接运行python seq2seq_attention.py即可. 参数设定如下图所示:
技术分享
除了上述项目运行时所需的必要参数, 模型参数也在seq2seq_attention.py中设定, 如下图所示, 包括学习率, 最小学习率(学习率会衰减但不会低于最小学习率), batch size, train模式encoder的RNN层数, 输入正文词汇数上限, 输出摘要词汇数上限, 最小长度限制, 隐层节点数, word embedding维度, 梯度截取比例, 每一个batch随机分类采样的数量.
技术分享
git项目textsum给的toy数据集太小, vocab也几乎不可用(一些常见的单词都没有覆盖到). 如果希望获得好的效果, 需要自己整理可用的数据集.
主要文件说明:
- seq2seq_attention.py: 主程序, 选择程序的运行模式, 设定参数, 建立模型, 启动tensorflow
- seq2seq_attention_model.py: 建立attention-based seq2seq model, 包括算法的encoder, decoder和attention模块, 都在Seq2SeqAttentionModel中完成.
- seq2seq_attention_decode.py: 读取数据, 调用beam_search解码
beam_search.py: beam search算法的核心程序

textsum程序解析

Google开源的textsum项目的具体算法是基于Hinton 2014年的Grammar as a Foreign Language这篇论文, 下面给出textsum工程中attention-based seq2seq模型的整体结构图, 图中所使用的名字与程序中的变量名一致, Seq2SeqAttentionModel是一个类, 定义在seq2seq_attention_model.py中; attention_decoder是一个函数, 定义在/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py中.
为了方便理解, 简单解释一下图中出现的符号,
技术分享
第一个符号表示从x1,x2到y的线性变换, 红色变量是训练过程要学习出来的.
技术分享
attention机制比较复杂也比较重要, 我们对这部分细化一下来看. attention decoder结构图如下:
技术分享
下图是对attention模块的细化:
技术分享
符号说明:
技术分享

为什么attention这个模块会起到效果呢? 因为attention模块会根据decoder当前时刻的LSTM单元的状态, 来调整对attention_states(encoder输出)的注意力. Attention_states不同位置获得的关注不一样. 这样我们就更大程度地, 关注了原文中, 对当前输出更为有用的信息, 输出结果也就更准确了. Attention模块输出结果和decoder模块原本的输出联合起来, 得到最终的输出结果.

相关链接:

  • 算法原理部分的论文:
    https://arxiv.org/abs/1509.00685
  • textsum开源程序链接:
    https://github.com/tensorflow/models/tree/master/textsum
  • textsum中使用的算法原理论文:
    https://arxiv.org/abs/1412.7449

  1. Rush在他的论文中提到了Bag-of-Words, Convolutional和Attention-Based三种编码器, 这里重在强调第三种. ?
<script type="text/javascript"> $(function () { $(‘pre.prettyprint code‘).each(function () { var lines = $(this).text().split(‘\n‘).length; var $numbering = $(‘
    ‘).addClass(‘pre-numbering‘).hide(); $(this).addClass(‘has-numbering‘).parent().append($numbering); for (i = 1; i <= lines; i++) { $numbering.append($(‘
  • ‘).text(i)); }; $numbering.fadeIn(1700); }); }); </script>

    TensorFlow文本摘要生成 - 基于注意力的序列到序列模型