首页 > 代码库 > 机器学习:利用卷积神经网络实现图像风格迁移 (一)

机器学习:利用卷积神经网络实现图像风格迁移 (一)

相信很多人都对之前大名鼎鼎的 Prisma 早有耳闻,Prisma 能够将一张普通的图像转换成各种艺术风格的图像,今天,我们将要介绍一下Prisma 这款软件背后的算法原理。就是发表于 2016 CVPR 一篇文章,

“ Image Style Transfer Using Convolutional Neural Networks”

算法的流程图主要如下:

技术分享

总得来说,就是利用一个训练好的卷积神经网络 VGG-19,这个网络在ImageNet 上已经训练过了。

给定一张风格图像 a<script type="math/tex" id="MathJax-Element-4239">a</script> 和一张普通图像 p<script type="math/tex" id="MathJax-Element-4240">p</script>,风格图像经过VGG-19 的时候在每个卷积层会得到很多 feature maps, 这些feature maps 组成一个集合 A<script type="math/tex" id="MathJax-Element-4241">A</script>,同样的,普通图像 p<script type="math/tex" id="MathJax-Element-4242">p</script> 通过 VGG-19 的时候也会得到很多 feature maps,这些feature maps 组成一个集合 P<script type="math/tex" id="MathJax-Element-4243">P</script>,然后生成一张随机噪声图像 x<script type="math/tex" id="MathJax-Element-4244">x</script>, 随机噪声图像 x<script type="math/tex" id="MathJax-Element-4245">x</script> 通过VGG-19 的时候也会生成很多feature maps,这些 feature maps 构成集合 G<script type="math/tex" id="MathJax-Element-4246">G</script> 和 F<script type="math/tex" id="MathJax-Element-4247">F</script> 分别对应集合 A<script type="math/tex" id="MathJax-Element-4248">A</script> 和 P<script type="math/tex" id="MathJax-Element-4249">P</script>, 最终的优化函数是希望调整 x<script type="math/tex" id="MathJax-Element-4250">x</script> 让 随机噪声图像 x<script type="math/tex" id="MathJax-Element-4251">x</script> 最后看起来既保持普通图像 p<script type="math/tex" id="MathJax-Element-4252">p</script> 的内容, 又有一定的风格图像 a<script type="math/tex" id="MathJax-Element-4253">a</script> 的风格。

content representation

在建立目标函数之前,我们需要先给出一些定义: 在CNN 中, 假设某一 layer 含有 Nl<script type="math/tex" id="MathJax-Element-4359">N_{l}</script> 个 filters, 那么将会生成 Nl<script type="math/tex" id="MathJax-Element-4360">N_{l}</script> 个 feature maps,每个 feature map 的维度为 Ml<script type="math/tex" id="MathJax-Element-4361">M_{l}</script> , Ml<script type="math/tex" id="MathJax-Element-4362">M_{l}</script> 是 feature map 的 高与宽的乘积。所以每一层 feature maps 的集合可以表示为 FlRNl×Ml<script type="math/tex" id="MathJax-Element-4363">F^{l} \in R^{N_{l} \times M_{l}}</script> , Flij<script type="math/tex" id="MathJax-Element-4364">F_{ij}^{l}</script> 表示第 i<script type="math/tex" id="MathJax-Element-4365">i</script>个 filter在 position j<script type="math/tex" id="MathJax-Element-4366">j</script> 上的 activation。

所以,我们可以给出 content 的 cost function:

Lcontent(p,x,l)=12ij(Flij?Plij)
<script type="math/tex; mode=display" id="MathJax-Element-4350">L_{content} (p,x,l) =\frac{1}{2}\sum_{ij} (F_{ij}^{l}-P_{ij}^{l})</script>

style representation

为了建立风格的representation,我们先利用 Gram matrix 去表示每一层各个 feature maps 之间的关系,GlRNl×Nl<script type="math/tex" id="MathJax-Element-4448">G^{l} \in R^{N_{l} \times N_{l}}</script> , Glij<script type="math/tex" id="MathJax-Element-4449">G_{ij}^{l}</script> 是 feature maps i,j<script type="math/tex" id="MathJax-Element-4450">i, j</script> 的内积:

Glij=kFlikFljk
<script type="math/tex; mode=display" id="MathJax-Element-4493"> G_{ij}^{l} =\sum_{k} F_{ik}^{l} F_{jk}^{l} </script>

利用 Gram matrix,我们可以建立每一层的关于 style 的 cost :

El=14N2lM2li,j(Glij?Alij)2
<script type="math/tex; mode=display" id="MathJax-Element-4656"> E_{l} =\frac{1}{4N_{l}^{2}M_{l}^{2}} \sum_{i,j} (G_{ij}^{l} - A_{ij}^{l})^{2} </script>

结合所有层,可以得到总的cost

Lstyle(a,x)=l=0LwlEl
<script type="math/tex; mode=display" id="MathJax-Element-4657"> L_{style} (a, x)= \sum_{l=0}^{L} w_{l} E_{l} </script>

最后将 content 和 style 的 cost 相结合,最终可以得到:

Ltotal(p,a,x)=αLcontent(p,x)+βLstyle(a,x)
<script type="math/tex; mode=display" id="MathJax-Element-5047"> L_{total} (p,a,x) =\alpha L_{content} (p,x) + \beta L_{style} (a, x) </script>

α,β<script type="math/tex" id="MathJax-Element-5048">\alpha , \beta</script> 表示权值,在建立 Lcontent<script type="math/tex" id="MathJax-Element-5049"> L_{content} </script> 的时候,用到了 VGG-19 的 conv4_2 层,而在建立 Lstyle<script type="math/tex" id="MathJax-Element-5050">L_{style}</script> 的时候,用到了VGG-19 的 conv1_1, conv2_1, conv3_1, conv4_1 以及 conv5_1。

下一篇博客里,我们将介绍基于 TensorFlow 的代码实现。

<script type="text/javascript"> $(function () { $(‘pre.prettyprint code‘).each(function () { var lines = $(this).text().split(‘\n‘).length; var $numbering = $(‘
    ‘).addClass(‘pre-numbering‘).hide(); $(this).addClass(‘has-numbering‘).parent().append($numbering); for (i = 1; i <= lines; i++) { $numbering.append($(‘
  • ‘).text(i)); }; $numbering.fadeIn(1700); }); }); </script>

    机器学习:利用卷积神经网络实现图像风格迁移 (一)