首页 > 代码库 > 对抗神经网络(Adversarial Nets)的介绍[1]
对抗神经网络(Adversarial Nets)的介绍[1]
- 对抗NN简介
- 概念介绍
- 对抗名字的由来及对抗过程
- 对抗NN的模型
- 对抗NN的模型和训练
- 判别网络D的最优值
- 模拟学习高斯分布
- 对抗NN实验结果
- 《生成对抗NN》代码的安装与运行
- 对抗网络相关论文
- 论文引用
一、对抗NN简介
大牛Ian J. Goodfellow 的2014年的《Generative Adversative Nets》第一次提出了对抗网络模型,短短两年的时间,这个模型在深度学习生成模型领域已经取得了不错的成果。论文提出了一个新的框架,可以利用对抗过程估计生成模型,相比之前的算法,可以认为是在无监督表示学习(Unsuperivised representation learning)上一个突破,现在主要的应用是用其生成自然图片(natural images)。
二、概念介绍
机器学习两个模型——生成模型和判别模型。
- 生成模型(Generative):学习到的是对于所观察数据的联合分布 比如2-D: p(x,y).
判别模型:学习到的是条件概率分布p(y|x),即学习到的是观察变量x的前提下的非观察变量的分布情况。
通俗的说,我们想通过生成模型来从数据中学习到分布情况,来生成新的数据。比如从大量的图片中学习,然后生成一张新的Photo.
而对于判别模型,最经典的应用,比如监督学习,那么对于分类问题,我想知道输入x,输出y的情况,那么y的值可以理解为数据的label。
而其中的对抗神经网络就是一个判别模型(Discriminative, D)和一个生成模型(Generative ,G)的组成的。
三、对抗名字的由来及对抗过程
刚才介绍了对抗网络其实是一个D和一个G组成的,那么G和D之间是如何对抗的呢?
先看以下一个场景:
- D是银行的Teller
G是一个Crook,专门制造假币。
那么其中的对抗过程就是,对于D来说,不断的学习,来进行真币的判断,G则是不断学习,制造更像真币的假币,来欺骗D,而最后的训练结果则是——D可以很好的区分真假币,但是G制造了“如假包换”的假币,而D分辨不出。
而对于对抗网络来说,D和G都是一个神经网络模型——MLP,那么D(判别模型)的输出是一个常量,这个常量表示“来自真币”的可能性。而对于G的输出则是一组向量,而这个向量表示的就是”假币”。
四、对抗NN的模型
图片1中的Z是G的输入,一般情况下是高斯随机分布生成的数据;其中G的输出是G(z),对于真实的数据,一般都为图片,将分布变量用X来表示。那么对于D的输出则是判断来自X的可能性,是一个常量。
五、对抗NN的训练和优化
对于G来说,要不断的欺骗D,那么也就是:
max log(D(G(z))) 目标函数1
对于D来说,要不断的学习防止被D欺骗,那么也就是:
max log(D(x)) + log(1 - D(G(z))) 目标函数2
使用梯度下降法(GD)训练,那么梯度如下。
对于目标函数1来说:
对于目标函数2来说:
训练过程
论文[1]给出了Algorithm 1,详细内容请查看原文,就是先进行训练D,然后训练D。其中论文也给出了公式来证明算法的可收敛性。
训练的几个trick:
- 论文提到的dropout的使用(应该是maxout layer)
- 每次进行多次D的训练,在进行G的训练,防止过拟合。
- 在训练之前,可以先进行预训练。
六、判别网络D的最优值
将X的概率密度分布函数(pdf)定义为
将G(Z)的pdf定义为
那么对于每一次训练,G如果固定的话,最优的输出D的值可以认为是
而且,最后训练的结果,是D=1/2=0.5。即此时有:
关于此详细证明可以查看原文。
七、对抗NN的实验结果
论文1用到的数据集包括,MNIST a)、TFD b)、CIFAR-10 c) d),数据集。对于不同的数据集,原文用到了不同的网络模型。
模型如下。
数据集 | G模型 | D模型 |
---|---|---|
mnist | relu+sigmoid 激活函数 | maxout+sigmoid |
tfd | 没有提到 | 没有提到 |
CIFAR-10 c) | 全连接+激活函数 | maxout+sigmoid |
CIFAR-10 d) | 反卷积层+激活函数 | maxoutconv+sigmoid |
详细模型介绍请查看开源项目中的yaml文件
https://github.com/goodfeli/adversarial
八、模拟学习高斯分布
论文给出的一张图。如下:
- D , blue , dashed line
- X , black , dotted line
- G , green , solid line
其中是通过对抗网络,让G(z)学习到x的分布,而x是符合高斯分布的,z是均匀分布。其中从(a)到(d)就是不断学习的过程,刚开始,G(z)和X的pdf是不吻合的,因为刚开始G(z)不可能一下就从随机变量中生成目标分布的数据。不过,最后,我们也可以看到(d)是最后学习到图像,其中下边两条平行线,z经过G()的映射已经和x的分布完全吻合(当然这是一个理想的情况),而且,D的输出是一条直线,就像上文提到的,D() = 1/2 一个常量。
Tensorflow 相关代码
(1)Discriminator’s loss
batch=tf.Variable(0)
obj_d=tf.reduce_mean(tf.log(D1)+tf.log(1-D2))
opt_d=tf.train.GradientDescentOptimizer(0.01)
.minimize(1-obj_d,global_step=batch,var_list=theta_d)
(2)Generator’s loss
batch=tf.Variable(0)
obj_g=tf.reduce_mean(tf.log(D2))
opt_g=tf.train.GradientDescentOptimizer(0.01)
.minimize(1-obj_g,global_step=batch,var_list=theta_g)
(3)Training Algorithms 1 , GoodFellow et al. 2014
for i in range(TRAIN_ITERS):
x= np.random.normal(mu,sigma,M)
z= np.random.random(M)
sess.run(opt_d, {x_node: x, z_node: z}) //先训练D
z= np.random.random(M)
sess.run(opt_g, {z_node: z}) //在训练G
以上代码是Tensorflow实现的用对抗NN生成高斯分布的例子。
九、大牛Good fellow 论文代码的安装与运行
对抗网络的作者Goodfellow也开源了自己的代码。
(1)项目链接
Adversarial链接
(2)下载与依赖库的安装
- 项目依赖pylearn2 ,要先安装pylearn2
- 本人git clone 了 pylearn2,adversarial 两个项目。添加了三个环境变量(根据自己路径添加)。
export PYLEARN2_VIEWER_COMMAND="eog --new-instance"
export PYLEARN2_DATA_PATH=/home/data
export PYTHONPATH=/home/code
- 其他python 依赖库可以通过pip或者apt-get安装。
(3)训练和测试
- 调用pylearn2的 train.py 和mnist.yaml进行训练。
pylearn2/scripts/train.py ./adversarial/mnist.yaml
测试如下
- 在adversarial 目录下运行
python show_samples_mnist_paper.py mnist.pkl
十、对抗网络相关论文和应用
博主做了一个开源项目,收集了对抗网络相关的paper和论文。
欢迎star和Contribution。
https://github.com/zhangqianhui/AdversarialNetsPapers
对抗NN的应用。这些应用都可以从我的开源项目中找到。
(1)论文[2]其中使用了CNN,用于图像生成,其中将D用于分类,取得了不错的效果。
(2)论文[3]将对抗NN用在了视频帧的预测,解决了其他算法容易产生fuzzy 块等问题。
(3)论文[4]将对抗NN用在了图片风格化处理可视化操作应用上。
十一、论文引用
[1]Generative Adversarial Networks.Goodfellow.
[2]Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.Alec Radford.
[3]Deep multi-scale video prediction beyond mean square error.Michael Mathieu.
[4]Generative Visual Manipulation on the Natural Image Manifold.Jun-Yan Zhu.ECCV 2016.
对抗神经网络(Adversarial Nets)的介绍[1]