首页 > 代码库 > DL4NLP——神经网络(二)循环神经网络:BPTT算法步骤整理;梯度消失与梯度爆炸

DL4NLP——神经网络(二)循环神经网络:BPTT算法步骤整理;梯度消失与梯度爆炸

      网上有很多Simple RNN的BPTT算法推导。下面用自己的记号整理一下。

      我之前有个习惯是用下标表示样本序号,这里不能再这样表示了,因为下标需要用做表示时刻。

      典型的Simple RNN结构如下:

技术分享

图片来源:[3]

      约定一下记号:

      输入序列 $\textbf x_{(1:T)} =(\textbf x_1,\textbf x_2,...,\textbf x_T)$,每个时刻的值都是一个维数是词表大小的one-hot列向量;

      标记序列 $\textbf y_{(1:T)} =(\textbf y_1,\textbf y_2,...,\textbf y_T)$ ,每个时刻的值都是一个维数是词表大小的one-hot列向量;

      输出序列 $\hat{\textbf y}_{(1:T)} =(\hat{\textbf y}_1,\hat{\textbf y}_2,...,\hat{\textbf y}_T)$ ,每个时刻的值都是一个维数是词表大小的列向量;

      隐层输出 $\textbf h_t\in\mathbb R^H$ ;

      隐层输入 $\textbf s_t\in\mathbb R^H$ ;

      过softmax之前输出层的输出 $\textbf z_t$ 。

(一)BPTT

      那么对于Simple RNN来说,前向传播过程如下(省略了偏置):

$$\textbf s_t=U\textbf h_{t-1}+W\textbf x_t$$

$$\textbf h_t=f (\textbf s_t)$$

$$\textbf z_t=V\textbf h_t$$

$$\hat{\textbf y}_t=\text{softmax}(\textbf z_t)$$

      其中 $f$ 是激活函数。注意,三个权重矩阵在时间维度上是共享的。这可以理解为:每个时刻都在执行相同的任务,所以是共享的。

      既然每个时刻都有输出 $\hat{\textbf y}_t$ ,那么相应地,每个时刻都会有损失。记 $t$ 时刻的损失为 $E_t$ ,那么对于样本 $\textbf x_{(1:T)}$ 来说,损失 $E$ 为

$$E=\sum_{t=1}^TE_t$$

      使用交叉熵损失函数,那么

$$E_t=-\textbf y_t^{\top}\log\hat{\textbf y}_t$$

      一、 $E$ 对 $V$ 的梯度

      下面首先求取 $E$ 对 $V$ 的梯度。根据chain rule:$\dfrac{\partial \textbf z}{\partial \textbf x}=\dfrac{\partial \textbf y}{\partial \textbf x}\dfrac{\partial \textbf z}{\partial \textbf y}$ 、$\dfrac{\partial z}{\partial X_{ij}}=(\dfrac{\partial z}{\partial\textbf y})^{\top}\dfrac{\partial\textbf y}{\partial X_{ij}}$ ,有

$$\frac{\partial E_t}{\partial V_{ij}}=(\frac{\partial E_t}{\partial\textbf z_t})^{\top}\frac{\partial\textbf z_t}{\partial V_{ij}}$$

      这里其实和BP是一样的,前一项相当于是误差项 $\delta$ ,后一项等于

$$\frac{\partial \textbf z_t}{\partial V_{ij}}=\frac{\partial V\textbf h_t}{\partial V_{ij}}=(0,...,[\textbf h_t]_j,...,0)^{\top}$$

只有第 $i$ 行非零,$[\textbf h_t]_j$ 是指 $\textbf h_t$ 的第 $j$ 个元素。参考上一篇博客的结尾部分,可知前一项等于

$$\frac{\partial E_t}{\partial\textbf z_t}=\hat{\textbf y}_t-\textbf y_t$$

(那里的求解用了一些技巧。如果用普通解法去推这个式子的话,可以参考 [6] 。)

      所以有

$$\frac{\partial E_t}{\partial V_{ij}}=[\hat{\textbf y}_t-\textbf y_t]_i[\textbf h_t]_j$$

从而有

$$\frac{\partial E_t}{\partial V}=(\hat{\textbf y}_t-\textbf y_t)\textbf h_t^{\top}=(\hat{\textbf y}_t-\textbf y_t)\otimes \textbf h_t$$

向量外积是矩阵的Kronecker积在向量下的特殊情况。因此,

$$\frac{\partial E}{\partial V}=\sum_{t=1}^T(\hat{\textbf y}_t-\textbf y_t)\otimes \textbf h_t$$

      二、 $E$ 对 $U$ 的梯度

      继续求取 $E$ 对 $U$ 的梯度。在求 $\frac{\partial E_t}{\partial U}$ 时,需要注意到一个事实,那就是不光 $t$ 时刻的隐状态与 $U$ 有关,之前所有时刻的隐状态都与 $U$ 有关。所以,根据chain rule:

$$\frac{\partial E_t}{\partial U}=\sum_{k=1}^t\frac{\partial\textbf s_k}{\partial U}\frac{\partial E_t}{\partial\textbf s_k}$$

      下面使用和之前类似的套路求解:先求对一个矩阵一个元素的梯度。

$$\frac{\partial E_t}{\partial U_{ij}}=\sum_{k=1}^t(\frac{\partial E_t}{\partial\textbf s_k})^{\top}\frac{\partial\textbf s_k}{\partial U_{ij}}$$

      前一项先定义为 $\delta_{t,k}=\dfrac{\partial E_t}{\partial\textbf s_k}$ ,对于后一项:

$$\frac{\partial\textbf s_k}{\partial U_{ij}}=\frac{\partial(U\textbf h_{k-1}+W\textbf x_k)}{\partial U_{ij}}=(0,...,[\textbf h_{k-1}]_j,...,0)^{\top}$$

只有第 $i$ 行非零,$[\textbf h_{k-1}]_j$ 是指 $\textbf h_{k-1}$ 的第 $j$ 个元素。现在来求解 $\delta_{t,k}=\dfrac{\partial E_t}{\partial\textbf s_k}$ ,使用上篇文章求 $\delta^{(l)}$ 的套路:

$$\begin{aligned}\delta_{t,k}&=\frac{\partial E_t}{\partial\textbf s_k}\\&=\frac{\partial \textbf h_k}{\partial\textbf s_{k}}\frac{\partial \textbf s_{k+1}}{\partial\textbf h_{k}}\frac{\partial E_t}{\partial\textbf s_{k+1}}\\&=\text{diag}(f‘(\textbf s_T))U^{\top}\delta_{t,k+1}\\&=f‘(\textbf s_{k})\odot (U^{\top}\delta_{t,k+1})\end{aligned}$$

一种特殊情况是当 $\delta_{t,t}$ ,有

$$\begin{aligned}\delta_{t,t}&=\frac{\partial E_t}{\partial\textbf s_t}\\&=\frac{\partial \textbf h_t}{\partial\textbf s_t}\frac{\partial \textbf z_t}{\partial\textbf h_t}\frac{\partial E_t}{\partial\textbf z_t}\\&=\text{diag}(f‘(\textbf s_{t}))V^{\top}(\hat{\textbf y}_t-\textbf y_t)\\&=f‘(\textbf s_{t})\odot (V^{\top}(\hat{\textbf y}_t-\textbf y_t))\end{aligned}$$

      所以,

$$\frac{\partial E_t}{\partial U_{ij}}=\sum_{k=1}^t[\delta_{t,k}]_i[\textbf h_{k-1}]_j$$

$$\frac{\partial E_t}{\partial U}=\sum_{k=1}^t\delta_{t,k}\textbf h_{k-1}^{\top}=\sum_{k=1}^t\delta_{t,k}\otimes\textbf h_{k-1}$$

因此,

$$\frac{\partial E}{\partial U}=\sum_{t=1}^T\sum_{k=1}^t\delta_{t,k}\otimes\textbf h_{k-1}$$ 

(二)梯度消失(gradient vanishing)与梯度爆炸(gradient exploding

      首先,

$$\frac{\partial E_t}{\partial U}=\frac{\partial \textbf h_t}{\partial U}\frac{\partial \hat{\textbf y}_t}{\partial \textbf h_t}\frac{\partial E_t}{\partial \hat{\textbf y}_t}$$

      这里的 $\dfrac{\partial \textbf h_t}{\partial U}$ 比较麻烦,是因为各个时刻共享了参数:$\textbf h_t$ 和 $\textbf h_{t-1}$ 、$U$ 有关,而 $\textbf h_{t-1}$ 又和 $\textbf h_{t-2}$ 、$U$ 有关。所以参照 [5] ,可以写成以下形式(注意 [5] 中的前向传播过程和 [4] 一样,与本文有区别):

$$\frac{\partial E_t}{\partial U}=\sum_{k=1}^t\frac{\partial \textbf h_k}{\partial U}\frac{\partial \textbf h_t}{\partial \textbf h_k}\frac{\partial \hat{\textbf y}_t}{\partial \textbf h_t}\frac{\partial E_t}{\partial \hat{\textbf y}_t}$$

其中,

$$\begin{aligned}\frac{\partial \textbf h_t}{\partial \textbf h_k}&=\prod_{i=k+1}^t\frac{\partial \textbf h_i}{\partial \textbf h_{i-1}}\\&=\prod_{i=k+1}^t\frac{\partial \textbf s_i}{\partial \textbf h_{i-1}}\frac{\partial f(\textbf s_i)}{\partial \textbf s_i}\\&=\prod_{i=k+1}^tU^{\top}\text{diag}{f‘(\textbf s_i)}\end{aligned}$$

从这个式子可以看出,当使用tanh或Logistic激活函数时,由于导数值分别在0到1之间、0到1/4之间,所以如果权重矩阵 $U$ 的范数也不很大,那么经过 $t-k$ 次传播后,$\dfrac{\partial \textbf h_t}{\partial \textbf h_k}$ 的范数会趋于0,也就导致了梯度消失问题。

      为了缓解梯度消失,可以使用ReLU、PReLU来作为激活函数,以及将 $U$ 初始化为单位矩阵(而不是用随机初始化)等方式。

      也就是说,虽然Simple RNN从理论上可以保持长时间间隔的状态之间的依赖关系,但是实际上只能学习到短期依赖关系。

      这就叫长期依赖问题,需要通过LSTM单元来缓解这个问题。

      而对于梯度爆炸问题,通常就是使用比较简单的策略,如gradient clipping一次迭代中,每个权重的梯度的平方和如果大于某个阈值,为避免权重矩阵的更新过于迅猛,那么求取一个缩放因子(阈值除以平方和),将所有的梯度乘以这个因子。

 

 

 

 

参考资料:

[1] 《神经网络与深度学习讲义》

[2] RECURRENT NEURAL NETWORKS TUTORIAL, PART 3 – BACKPROPAGATION THROUGH TIME AND VANISHING GRADIENTS

[3] BPTT算法推导

[4] On the difficulty of training RNN

[5] Rucurrent nets and LSTM

[6] LSTM的BPTT推导

 

DL4NLP——神经网络(二)循环神经网络:BPTT算法步骤整理;梯度消失与梯度爆炸