首页 > 代码库 > 对抗神经网络(Adversarial Nets)的介绍[1]

对抗神经网络(Adversarial Nets)的介绍[1]

  • 对抗NN简介
  • 概念介绍
  • 对抗名字的由来及对抗过程
  • 对抗NN的模型
  • 对抗NN的模型和训练
  • 判别网络D的最优值
  • 模拟学习高斯分布
  • 对抗NN实验结果
  • 《生成对抗NN》代码的安装与运行
  • 对抗网络相关论文
  • 论文引用

一、对抗NN简介

大牛Ian J. Goodfellow 的2014年的《Generative Adversative Nets》第一次提出了对抗网络模型,短短两年的时间,这个模型在深度学习生成模型领域已经取得了不错的成果。论文提出了一个新的框架,可以利用对抗过程估计生成模型,相比之前的算法,可以认为是在无监督表示学习(Unsuperivised representation learning)上一个突破,现在主要的应用是用其生成自然图片(natural images)。

二、概念介绍

机器学习两个模型——生成模型和判别模型。

  • 生成模型(Generative):学习到的是对于所观察数据的联合分布 比如2-D: p(x,y).
  • 判别模型:学习到的是条件概率分布p(y|x),即学习到的是观察变量x的前提下的非观察变量的分布情况。

    通俗的说,我们想通过生成模型来从数据中学习到分布情况,来生成新的数据。比如从大量的图片中学习,然后生成一张新的Photo.
    而对于判别模型,最经典的应用,比如监督学习,那么对于分类问题,我想知道输入x,输出y的情况,那么y的值可以理解为数据的label。

而其中的对抗神经网络就是一个判别模型(Discriminative, D)和一个生成模型(Generative ,G)的组成的。

三、对抗名字的由来及对抗过程

刚才介绍了对抗网络其实是一个D和一个G组成的,那么G和D之间是如何对抗的呢?
先看以下一个场景:

  • D是银行的Teller
  • G是一个Crook,专门制造假币。

    那么其中的对抗过程就是,对于D来说,不断的学习,来进行真币的判断,G则是不断学习,制造更像真币的假币,来欺骗D,而最后的训练结果则是——D可以很好的区分真假币,但是G制造了“如假包换”的假币,而D分辨不出。

而对于对抗网络来说,D和G都是一个神经网络模型——MLP,那么D(判别模型)的输出是一个常量,这个常量表示“来自真币”的可能性。而对于G的输出则是一组向量,而这个向量表示的就是”假币”。

四、对抗NN的模型

技术分享

图片1

图片1中的Z是G的输入,一般情况下是高斯随机分布生成的数据;其中G的输出是G(z),对于真实的数据,一般都为图片,将分布变量用X来表示。那么对于D的输出则是判断来自X的可能性,是一个常量。

五、对抗NN的训练和优化

对于G来说,要不断的欺骗D,那么也就是:

max log(D(G(z)))                           目标函数1

对于D来说,要不断的学习防止被D欺骗,那么也就是:

max log(D(x)) + log(1 - D(G(z)))           目标函数2

使用梯度下降法(GD)训练,那么梯度如下。

对于目标函数1来说:

技术分享

对于目标函数2来说:

技术分享

训练过程

论文[1]给出了Algorithm 1,详细内容请查看原文,就是先进行训练D,然后训练D。其中论文也给出了公式来证明算法的可收敛性。

训练的几个trick:

  • 论文提到的dropout的使用(应该是maxout layer)
  • 每次进行多次D的训练,在进行G的训练,防止过拟合。
  • 在训练之前,可以先进行预训练。

六、判别网络D的最优值

将X的概率密度分布函数(pdf)定义为技术分享
将G(Z)的pdf定义为技术分享

那么对于每一次训练,G如果固定的话,最优的输出D的值可以认为是

技术分享

而且,最后训练的结果,是D=1/2=0.5。即此时有:

技术分享

关于此详细证明可以查看原文。

七、对抗NN的实验结果

论文1用到的数据集包括,MNIST a)、TFD b)、CIFAR-10 c) d),数据集。对于不同的数据集,原文用到了不同的网络模型。

技术分享

图片2-实验结果

模型如下。

数据集 G模型 D模型
mnist relu+sigmoid 激活函数 maxout+sigmoid
tfd 没有提到 没有提到
CIFAR-10 c) 全连接+激活函数 maxout+sigmoid
CIFAR-10 d) 反卷积层+激活函数 maxoutconv+sigmoid

详细模型介绍请查看开源项目中的yaml文件

https://github.com/goodfeli/adversarial

八、模拟学习高斯分布

论文给出的一张图。如下:

技术分享

  • D , blue , dashed line
  • X , black , dotted line
  • G , green , solid line

其中是通过对抗网络,让G(z)学习到x的分布,而x是符合高斯分布的,z是均匀分布。其中从(a)到(d)就是不断学习的过程,刚开始,G(z)和X的pdf是不吻合的,因为刚开始G(z)不可能一下就从随机变量中生成目标分布的数据。不过,最后,我们也可以看到(d)是最后学习到图像,其中下边两条平行线,z经过G()的映射已经和x的分布完全吻合(当然这是一个理想的情况),而且,D的输出是一条直线,就像上文提到的,D() = 1/2 一个常量。

Tensorflow 相关代码

(1)Discriminator’s loss

batch=tf.Variable(0)
obj_d=tf.reduce_mean(tf.log(D1)+tf.log(1-D2))
opt_d=tf.train.GradientDescentOptimizer(0.01)
              .minimize(1-obj_d,global_step=batch,var_list=theta_d)

(2)Generator’s loss

batch=tf.Variable(0)
obj_g=tf.reduce_mean(tf.log(D2))
opt_g=tf.train.GradientDescentOptimizer(0.01)
              .minimize(1-obj_g,global_step=batch,var_list=theta_g)

(3)Training Algorithms 1 , GoodFellow et al. 2014

for i in range(TRAIN_ITERS):
    x= np.random.normal(mu,sigma,M)    
    z= np.random.random(M)      
    sess.run(opt_d, {x_node: x, z_node: z})    //先训练D 
    z= np.random.random(M)     
    sess.run(opt_g, {z_node: z})               //在训练G

以上代码是Tensorflow实现的用对抗NN生成高斯分布的例子。

九、大牛Good fellow 论文代码的安装与运行

对抗网络的作者Goodfellow也开源了自己的代码。

(1)项目链接

Adversarial链接

(2)下载与依赖库的安装

  • 项目依赖pylearn2 ,要先安装pylearn2
  • 本人git clone 了 pylearn2,adversarial 两个项目。添加了三个环境变量(根据自己路径添加)。
export PYLEARN2_VIEWER_COMMAND="eog --new-instance"
export PYLEARN2_DATA_PATH=/home/data
export PYTHONPATH=/home/code
  • 其他python 依赖库可以通过pip或者apt-get安装。

(3)训练和测试

  • 调用pylearn2的 train.py 和mnist.yaml进行训练。
pylearn2/scripts/train.py ./adversarial/mnist.yaml

测试如下

  • 在adversarial 目录下运行
python show_samples_mnist_paper.py mnist.pkl

十、对抗网络相关论文和应用

博主做了一个开源项目,收集了对抗网络相关的paper和论文。
欢迎star和Contribution。

https://github.com/zhangqianhui/AdversarialNetsPapers

对抗NN的应用。这些应用都可以从我的开源项目中找到

(1)论文[2]其中使用了CNN,用于图像生成,其中将D用于分类,取得了不错的效果。

(2)论文[3]将对抗NN用在了视频帧的预测,解决了其他算法容易产生fuzzy 块等问题。

技术分享

(3)论文[4]将对抗NN用在了图片风格化处理可视化操作应用上。

技术分享

十一、论文引用

[1]Generative Adversarial Networks.Goodfellow.
[2]Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.Alec Radford.
[3]Deep multi-scale video prediction beyond mean square error.Michael Mathieu.
[4]Generative Visual Manipulation on the Natural Image Manifold.Jun-Yan Zhu.ECCV 2016.

<script type="text/javascript"> $(function () { $(‘pre.prettyprint code‘).each(function () { var lines = $(this).text().split(‘\n‘).length; var $numbering = $(‘
    ‘).addClass(‘pre-numbering‘).hide(); $(this).addClass(‘has-numbering‘).parent().append($numbering); for (i = 1; i <= lines; i++) { $numbering.append($(‘
  • ‘).text(i)); }; $numbering.fadeIn(1700); }); }); </script>

    对抗神经网络(Adversarial Nets)的介绍[1]