为什么要先进行案例研究?
没有比较好的数学基础,直接接触深度学习会非常抽象,所以这里我们先通过一个预测 Pokemon Go 的 Combat Power (CP) 值的案例,打开深度学习的大门。
Regression (回归)
应用举例(预测Pokemon Go 进化后的战斗力)
比如估计一只神奇宝贝进化后的 CP 值(战斗力)。
下面是一只妙蛙种子,可以进化为妙蛙草,现在的CP值是14,我们想估计进化后的CP值是多少;进化需要糖果,好处就是如果它进化后CP值不满意,那就不用浪费糖果来进化它了,可以选择性价比高的神奇宝贝。
输入用了一些不同的 x<script type="math/tex" id="MathJax-Element-1">x</script> 来代表不同的属性,比如战斗力用 xcp<script type="math/tex" id="MathJax-Element-2">x_{cp}</script> 来表示,物种 xs<script type="math/tex" id="MathJax-Element-3">x_{s}</script> 来表示…
输出就是进化后的CP值
三个步骤
上一篇提到了机器学习的三个步骤:
Step1.确定一组函数(Model)。
Step2.将训练集对函数集进行训练。
Step3.挑选出“最好”的函数 f?<script type="math/tex" id="MathJax-Element-4">f^{*}</script>
然后就可以使用 f?<script type="math/tex" id="MathJax-Element-5">f^{*}</script> 来对新的测试集进行检测。
Step1: Model
这个model 应该长什么样子呢,先写一个简单的:我们可以认为进化后的CP值 y<script type="math/tex" id="MathJax-Element-6">y</script> 等于进化前的CP值 xcp<script type="math/tex" id="MathJax-Element-7">x_{cp}</script> 乘以一个参数 w<script type="math/tex" id="MathJax-Element-8">w</script> 再加上一个参数 b<script type="math/tex" id="MathJax-Element-9">b</script> 。
y=b+w?xcp(1?1)
<script type="math/tex; mode=display" id="MathJax-Element-316">
y = b + w \cdot x_{cp} \qquad (1-1)
</script>
w<script type="math/tex" id="MathJax-Element-317">w</script> 和 b<script type="math/tex" id="MathJax-Element-318">b</script> 是参数,可以是任何数值。
可以有
f1:y=10.0+9.0?xcpf2:y=9.8+9.2?xcpf3:y=?0.8?1.2?xcp
<script type="math/tex; mode=display" id="MathJax-Element-319">
f_{1}: y = 10.0 + 9.0 \cdot x_{cp}\f_{2}: y = 9.8 + 9.2 \cdot x_{cp}\f_{3}: y = -0.8 -1.2 \cdot x_{cp}
</script>
这个函数集中可以有无限多的 function。所以我们用 y=b+w?xcp<script type="math/tex" id="MathJax-Element-320">y = b + w \cdot x_{cp} </script> 代表这些 function 所成的集合。还有比如上面的 f3<script type="math/tex" id="MathJax-Element-321">f_{3}</script> ,明显是不正确的,因为CP值有个条件都是正的,那乘以 ?1.2<script type="math/tex" id="MathJax-Element-322">-1.2</script> 就变成负的了,所以我们接着就要根据训练集来找到,这个 function set 里面,哪个是合理的 function。
我们将式1-1 称作 Linear model, Linear model 形式为:
y=b+∑wixi
<script type="math/tex; mode=display" id="MathJax-Element-323">
y = b + \sum w_{i}x_{i}
</script>
xi<script type="math/tex" id="MathJax-Element-324">x_{i}</script> 就是神奇宝贝的各种不同的属性,身高、体重等等,我们将这些称之为 “feature(特征)”;wi<script type="math/tex" id="MathJax-Element-325">w_{i}</script> 称为 weight(权重),b 称为 bias(偏差)。
Step2: 方程的好坏
现在就需要搜集训练集,这里的数据集是 Supervised 的,所以需要 function 的输入和输出(数值),举例抓了一只杰尼龟,进化前的CP值为612,用 x1<script type="math/tex" id="MathJax-Element-326">x^{1}</script> 代表这只杰尼龟进化前的CP值,即用上标标示一个完整对象的编号;进化后的CP值为 979,用 y? 1<script type="math/tex" id="MathJax-Element-327">\hat{y}^{1}</script> 表示进化后的CP值,用 hat(字母头顶的上尖符号)来表示这是一个正确的值,是实际观察到function该有的输出。
下面我们来看真正的数据集(来源 Source: https://www.openintro.org/stat/data/?data=http://www.mamicode.com/pokemon)
来看10只神奇宝贝的真实数据,x<script type="math/tex" id="MathJax-Element-328">x</script> 轴代表进化前的CP值,y<script type="math/tex" id="MathJax-Element-329">y</script> 轴代表进化后的CP值。
有了训练集,为了评价 function 的好坏,我们需要定义一个新的函数,称为 Loss function (损失函数),定义如下:
Loss function L<script type="math/tex" id="MathJax-Element-330">L</script> :
input: a function, output: how bad it is
Loss function是比较特别的函数,是函数的函数,因为它的输入是一个函数,而输出是表示输入的函数有多不好。 可以写成下面这种形式:
L(f)=L(w,b)
<script type="math/tex; mode=display" id="MathJax-Element-331">
L(f) = L(w, b)
</script>
损失函数是由一组参数 w和b决定的,所以可以说损失函数是在衡量一组参数的好坏。
这里用比较常见的定义形式:
L(f)=L(w,b)=∑n=110(y? n?(b+w?xncp))2(1?2)
<script type="math/tex; mode=display" id="MathJax-Element-332">
L(f) = L(w, b) =\sum_{n=1}^{10}\left(\hat{y}^{n} -(b + w\cdot x_{cp}^{n})\right)^{2} \qquad (1-2)
</script>
将实际的数值 y? n<script type="math/tex" id="MathJax-Element-333">\hat{y}^{n}</script> 减去 估测的数值 b+w?xncp<script type="math/tex" id="MathJax-Element-334">b + w\cdot x_{cp}^{n}</script>,然后再给平方,就是 Estimation error(估测误差,总偏差);最后将估测误差加起来就是我们定义的损失函数。
这里不取各个偏差的代数和∑10n=1y? n?(b+w?xncp)<script type="math/tex" id="MathJax-Element-335">\sum_{n=1}^{10}\hat{y}^{n} -(b + w\cdot x_{cp}^{n})</script> 作为总偏差,这是因为这些偏差(y? i?(b+w?xicp)<script type="math/tex" id="MathJax-Element-336">\hat{y}^{i} -(b + w\cdot x_{cp}^{i})</script>)本身有正有负,如果简单地取它们的代数和,就可能互相抵消,这是虽然偏差的代数和很小,却不能保证各个偏差都很小。所以按照式1-2,是这些偏差的平方和最小,就可以保证每一个偏差都很小。
为了更加直观,来对损失函数进行作图:
图上每个点都代表一个方程,比如红色的那个点代表 y=?180?2?xcp<script type="math/tex" id="MathJax-Element-337">y=-180-2\cdot x_{cp}</script> 。颜色代表用这个点的方程得到的损失函数有多不好,颜色越偏红色,代表数值越大,越偏蓝色蓝色,代表方程越好。最好的方程就是图中叉叉标记的点。
Step3:最好的方程
定好了损失函数,可以衡量每一个方程的好坏,接下来需要从函数集中挑选一个最好的方程。将这个过程数学化:
f?=argminfL(f)
<script type="math/tex; mode=display" id="MathJax-Element-338">
f^{*}=\arg \min_{f} L(f) \</script>
w?,b?=argminw,bL(w,b)=argminw,b∑n=110(y? n?(b+w?xncp))2(1?3)
<script type="math/tex; mode=display" id="MathJax-Element-339">
w^{*},b^{*} = \arg \min_{w,b} L(w,b)=\arg \min_{w,b} \sum_{n=1}^{10}\left(\hat{y}^{n} -(b + w\cdot x_{cp}^{n})\right)^{2} \qquad (1-3)
</script>
由于这里举例的特殊性,对于式1-3,直接使用最小二乘法即可解出最优的 w 和 b,使得总偏差最小。
简单说一下最小二乘法,对于二元函数 f(x,y)<script type="math/tex" id="MathJax-Element-340">f(x,y)</script>,函数的极值点必为 ?f?x<script type="math/tex" id="MathJax-Element-341">\frac{\partial f}{\partial x}</script> 及?f?y<script type="math/tex" id="MathJax-Element-342">\frac{\partial f}{\partial y}</script> 同时为零或至少有一个偏导数不存在的点;这是极值的必要条件。用这个极值条件可以解出w 和 b。(详情请参阅《数学分析,第三版下册,欧阳光中 等编》第十五章,第一节)
但这里会使用另外一种做法,Gradient Descent(最速下降法),最速下降法不光能解决式1-3 这一种问题;实际上只要 L<script type="math/tex" id="MathJax-Element-343">L</script> 是可微分的,都可以用最速下降法来处理。
Gradient Descent(梯度下降法)
简单来看一下梯度下降法的做法。
考虑只有一个参数 w<script type="math/tex" id="MathJax-Element-344">w</script> 的损失函数,随机的选取一个初始点,计算 w=w0<script type="math/tex" id="MathJax-Element-345">w = w^{0}</script> 时 L<script type="math/tex" id="MathJax-Element-346">L</script> 对 w<script type="math/tex" id="MathJax-Element-347">w</script> 的微分,然后顺着切线下降的方向更改 w<script type="math/tex" id="MathJax-Element-348">w</script> 的值(因为这里是求极小值),即斜率为负,增加w<script type="math/tex" id="MathJax-Element-349">w</script> ;斜率为正,减小w<script type="math/tex" id="MathJax-Element-350">w</script> .
那么每次更改 w<script type="math/tex" id="MathJax-Element-351">w</script> ,更改多大,用 ηdLdw|w=w0<script type="math/tex" id="MathJax-Element-352">\eta \frac{\mathrm{d}L}{\mathrm{d}w} |_{w=w^{0}}</script> 表示,η<script type="math/tex" id="MathJax-Element-353">\eta</script> 被称为“learning rate”学习速率。
由于这里斜率是负的,所以是 w0?ηdLdw|w=w0<script type="math/tex" id="MathJax-Element-354">w^{0} - \eta \frac{\mathrm{d}L}{\mathrm{d}w} |_{w=w^{0}}</script> ,得到 w1<script type="math/tex" id="MathJax-Element-355">w^{1}</script>;接着就是重复上述步骤。
直到找到一个点,这个点的斜率为0。但是例子中的情况会比较疑惑,这样的方法很可能找到的只是局部极值,并不是全局极值,但这是由于我们例子的原因,针对回归问题来说,是不存在局部极值的,只有全局极值。所以这个方法还是可以使用。
下面来看看两个参数的问题。
两个参数的区别就是每次需要对两个参数求偏微分,然后同理更新参数的值。
关于梯度可以参阅《数学分析,第三版下册,欧阳光中 等编》,第十四章第六节。也可以大概看看百度百科又或者wikipedia
将上述做法可视化:
同理梯度下降的缺陷如下图:
可能只是找到了局部极值,但是对于线性回归,可以保证所选取的损失函数式1-2是 convex(凸的,即只存在唯一极值)。上图右边就是损失函数的等高线图,可以看出是一圈一圈向内减小的。
结果怎么样呢?
将求出的结果绘图如下
可以计算出训练集上的偏差绝对值之和为 31.9
但真正关心的并不是在训练集上的偏差,而是Generalization的情况,就是需要在新的数据集(测试集)上来计算偏差。如下图:
使用十个新的神奇宝贝的数据作为测试集计算出偏差绝对值之和为35.
接下来考虑是否能够做的更好,可能并不只是简单的直线,考虑其他model的情况:
比如重新设计一个model,多一个二次项,来求出参数,得到Average Error为15.4,在训练集上看起来更好了。在测试集上得出的Average Error是18.4,确实是更好的Model。
再考虑三次项:
得到的结果看起来和二次项时候的结果差别不大,稍微好一点点。也可以看到w3<script type="math/tex" id="MathJax-Element-356">w_{3}</script>已经非常小了,说明三次项影响已经不大了。
再考虑四次项:
此时在训练集上可以做的更好,但是测试集的结果变差了。
再考虑五次项:
可以看到测试集的结果非常差。
Overfitting(过拟合,过度学习)
将训练集上的Average Error变化进行作图:
可以看到训练集上的 Average Error 逐渐变小。
上面的那些model,高次项是包含低次项的function。理论上确实次幂越高越复杂的方程,可以让训练集的结果越低。但加上测试集的结果:
观察得出结果:虽然越复杂的model可以在训练集上得到更好的结果,但越复杂的model并不一定在测试集上有好的结果。这个结论叫做“Overfitting(过拟合)”。
如果此时要选model的话,最好的选择就是三次项式子的model。
实际生活中典型的学驾照,学驾照的时候在驾校的训练集上人们可以做的很好,但上路之后真正的测试集就完全无法驾驭。这里只是举个训练集很好,而测试集结果很差的例子^_^
如果数据更多会怎样?
考虑60只神奇宝贝的数据
可以看出物种也是一个关键性的因素,只考虑进化前的CP值是太局限的,刚才的model就设计的不太好。
新的model如下
将这个model写成linear model的形式:
来看做出来的结果:
不同种类的神奇宝贝用的参数不同,用颜色区分。此时model在训练集上可以做的更好,在测试集上的结果也是比之前的18.1更好。
还有其他因素的影响吗?
比如对身高,体重,生命值进行绘图:
重新设计model:
考虑上生命值(xhp<script type="math/tex" id="MathJax-Element-357">x_{hp}</script>)、高度(xh<script type="math/tex" id="MathJax-Element-358">x_{h}</script>)、重量(xw<script type="math/tex" id="MathJax-Element-359">x_{w}</script>)
这么复杂的model,理论上训练集上可以得到更好的结果,实际为1.9,确实是更低。但是测试集的结果就过拟合了。
Regularization(正则化)
对于上面那么多参数结果并不理想的情况,这里进行正则化处理,将之前的损失函数进行修改:
y=b+∑wixi(1?4)L(f)=L(w,b)=∑n=110(y? n?(b+w?xncp))2+λ∑(wi)2(1?5)
<script type="math/tex; mode=display" id="MathJax-Element-360">
y = b + \sum w_{i}x_{i} \qquad (1-4)\L(f) = L(w, b) =\sum_{n=1}^{10}\left(\hat{y}^{n} -(b + w\cdot x_{cp}^{n})\right)^{2} + \lambda \sum (w_{i})^{2} \qquad (1-5)
</script>
式1-5 中多加了一项: λ∑(wi)2<script type="math/tex" id="MathJax-Element-361">\lambda \sum (w_{i})^{2}</script> ,结论是wi<script type="math/tex" id="MathJax-Element-362">w_{i}</script>越小,则方程(式1-4)就越好。还可以说当 wi<script type="math/tex" id="MathJax-Element-363">w_{i}</script> 越小,则方程越平滑。
平滑的意思是当输入变化时,输出对输入的变化不敏感。比如式1-5 中输入增加了 Δxi<script type="math/tex" id="MathJax-Element-364">\Delta x_{i}</script> 则输入就增加了 wiΔxi<script type="math/tex" id="MathJax-Element-365">w_{i}\Delta x_{i}</script> ,可以看出当wi<script type="math/tex" id="MathJax-Element-366">w_{i}</script>越小,输出变化越不明显。还比如测试集的输入有一些噪音数据,越平滑的方程就会受到更小的影响。
上图是对 λ<script type="math/tex" id="MathJax-Element-367">\lambda</script>进行调整得出的结果。当 λ<script type="math/tex" id="MathJax-Element-368">\lambda</script> 越大的时候, λ∑(wi)2<script type="math/tex" id="MathJax-Element-369">\lambda \sum (w_{i})^{2}</script> 这一项的影响力越大,所以当λ<script type="math/tex" id="MathJax-Element-370">\lambda</script> 越大的时候,方程越平滑。
训练集上得到的结果是:当 λ<script type="math/tex" id="MathJax-Element-371">\lambda</script> 越大的时候,在训练集上得到的Error 是越大的。这是合理的现象,因为当 λ<script type="math/tex" id="MathJax-Element-372">\lambda</script> 越大的时候,就越倾向于考虑 w<script type="math/tex" id="MathJax-Element-373">w</script> 本身值,减少考虑error。但是测试集上得到的error 是先减小又增大的。这里喜欢比较平滑的function,因为上面讲到对于噪音数据有很好的鲁棒性,所以开始增加 λ<script type="math/tex" id="MathJax-Element-374">\lambda</script> 的时候性能是越来越好;但是又不喜欢太平滑的function,最平滑的function就是一条水平线了,那就相当于什么都没有做,所以太平滑的function又会得到糟糕的结果。
所以最后这件事情就是找到最合适的 λ<script type="math/tex" id="MathJax-Element-375">\lambda</script> ,此时带进式1-5 求出b<script type="math/tex" id="MathJax-Element-376">b</script> 和 wi<script type="math/tex" id="MathJax-Element-377">w_{i}</script>,得到的function就是最优的function。
对于Regularization 的时候,多加的一项:λ∑(wi)2<script type="math/tex" id="MathJax-Element-378">\lambda \sum (w_{i})^{2}</script>,并没有考虑 b<script type="math/tex" id="MathJax-Element-379">b</script> ,是因为期望得到平滑的function,但bias这项并不影响平滑程度,它只是将function上下移动,跟function的平滑程度是没有关系的。
总结
- Pokemon:原始的CP值极大程度的决定了进化后的CP值,但可能还有其他的一些因素。
- Gradient descent:梯度下降的做法;后面会讲到它的理论依据和要点。
- Overfitting和Regularization:过拟合和正则化,主要介绍了表象;后面会讲到更多这方面的理论
新博客地址:http://yoferzhang.com/post/20170326ML02Regression
<script type="text/javascript"> $(function () { $(‘pre.prettyprint code‘).each(function () { var lines = $(this).text().split(‘\n‘).length; var $numbering = $(‘