(Original) The apply function in torch
Please credit the source when reposting:
http://www.cnblogs.com/darkknightzh/p/6221633.html
The apply function in torch walks through every module of a model, one after another. Under the hood it uses a depth-first traversal.
Its implementation is shown below (see torch/install/share/lua/5.1/nn/Module.lua):
```lua
-- Run a callback (called with the module as an argument) in preorder over this
-- module and its children.
--
function Module:apply(callback)
   callback(self)

   if self.modules then
      for _, module in ipairs(self.modules) do
         module:apply(callback)
      end
   end
end
```
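As the code shows, apply first calls callback on the module itself and then calls apply recursively on each child in self.modules, in order. The recursion stops at leaf modules, which have no modules field, so the overall effect is a preorder, depth-first walk over the whole module tree.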
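To see the recursion in isolation, here is a minimal pure-Lua sketch that reproduces the logic of Module:apply on plain tables, so it runs without Torch. The Node class, the newNode helper and the module names are hypothetical stand-ins for nn modules; only the body of apply mirrors the code above.

```lua
-- Hypothetical stand-in for nn.Module: a node has a name and an optional
-- `modules` array of children (leaf nodes have no `modules` field).
local Node = {}
Node.__index = Node

local function newNode(name, children)
   return setmetatable({ name = name, modules = children }, Node)
end

-- Same logic as Module:apply: visit self first, then recurse into children.
function Node:apply(callback)
   callback(self)
   if self.modules then
      for _, module in ipairs(self.modules) do
         module:apply(callback)
      end
   end
end

-- A small two-level tree: Sequential(Linear, Sequential(ReLU, Linear), Tanh)
local net = newNode('Sequential', {
   newNode('Linear'),
   newNode('Sequential', { newNode('ReLU'), newNode('Linear') }),
   newNode('Tanh'),
})

local order = {}
net:apply(function(m) order[#order + 1] = m.name end)
print(table.concat(order, ' -> '))
-- prints: Sequential -> Linear -> Sequential -> ReLU -> Linear -> Tanh
```

The inner Sequential is reported before its own children, and Tanh only after the whole inner branch is finished, which is the same pattern visible in the Torch output of the test model.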
Consider the following test code:
```lua
require 'nn'
require 'dpnn'  -- nn.Inception comes from the dpnn package

local function createModel()
   local net = nn.Sequential()
   net:add(nn.SpatialConvolutionMM(3, 64, 7, 7, 2, 2, 3, 3))
   net:add(nn.SpatialBatchNormalization(64))
   net:add(nn.ReLU())
   net:add(nn.SpatialMaxPooling(3, 3, 2, 2, 1, 1))
   -- first Inception: batch normalization, max-pooling branch
   net:add(nn.Inception{
      inputSize = 192,
      kernelSize = {3, 5},
      kernelStride = {1, 1},
      outputSize = {128, 32},
      reduceSize = {96, 16, 32, 64},
      pool = nn.SpatialMaxPooling(3, 3, 1, 1, 1, 1),
      batchNorm = true
   })
   -- second Inception: no batch normalization, L2-pooling branch
   net:add(nn.Inception{
      inputSize = 256,
      kernelSize = {3, 5},
      kernelStride = {1, 1},
      outputSize = {128, 64},
      reduceSize = {96, 32, 64, 64},
      pool = nn.SpatialLPPooling(256, 2, 3, 3, 1, 1),
      batchNorm = false
   })
   net:add(nn.SpatialAveragePooling(7, 7))
   net:add(nn.View(320))
   net:add(nn.Linear(320, 128))
   net:add(nn.Normalize(2))
   return net
end

torch.setdefaulttensortype('torch.FloatTensor')
local model = createModel()
--print(model)

-- Print a running counter and the module each time apply invokes the callback
local tt = 0
model:apply(function(module)
   tt = tt + 1
   print(tt, module)
end)
```
The output is:
```
1	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> output]
  (1): nn.SpatialConvolutionMM(3 -> 64, 7x7, 2,2, 3,3)
  (2): nn.SpatialBatchNormalization
  (3): nn.ReLU
  (4): nn.SpatialMaxPooling(3x3, 2,2, 1,1)
  (5): nn.Inception @ nn.DepthConcat {
    input
      |`-> (1): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
      |      (1): nn.SpatialConvolution(192 -> 96, 1x1)
      |      (2): nn.SpatialBatchNormalization
      |      (3): nn.ReLU
      |      (4): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
      |      (5): nn.SpatialBatchNormalization
      |      (6): nn.ReLU
      |    }
      |`-> (2): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
      |      (1): nn.SpatialConvolution(192 -> 16, 1x1)
      |      (2): nn.SpatialBatchNormalization
      |      (3): nn.ReLU
      |      (4): nn.SpatialConvolution(16 -> 32, 5x5, 1,1, 2,2)
      |      (5): nn.SpatialBatchNormalization
      |      (6): nn.ReLU
      |    }
      |`-> (3): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> (4) -> output]
      |      (1): nn.SpatialMaxPooling(3x3, 1,1, 1,1)
      |      (2): nn.SpatialConvolution(192 -> 32, 1x1)
      |      (3): nn.SpatialBatchNormalization
      |      (4): nn.ReLU
      |    }
      |`-> (4): nn.Sequential {
             [input -> (1) -> (2) -> (3) -> output]
             (1): nn.SpatialConvolution(192 -> 64, 1x1)
             (2): nn.SpatialBatchNormalization
             (3): nn.ReLU
           }
       ... -> output
  }
  (6): nn.Inception @ nn.DepthConcat {
    input
      |`-> (1): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> (4) -> output]
      |      (1): nn.SpatialConvolution(256 -> 96, 1x1)
      |      (2): nn.ReLU
      |      (3): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
      |      (4): nn.ReLU
      |    }
      |`-> (2): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> (4) -> output]
      |      (1): nn.SpatialConvolution(256 -> 32, 1x1)
      |      (2): nn.ReLU
      |      (3): nn.SpatialConvolution(32 -> 64, 5x5, 1,1, 2,2)
      |      (4): nn.ReLU
      |    }
      |`-> (3): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> output]
      |      (1): nn.Sequential {
      |        [input -> (1) -> (2) -> (3) -> (4) -> output]
      |        (1): nn.Square
      |        (2): nn.SpatialAveragePooling(3x3, 1,1)
      |        (3): nn.MulConstant
      |        (4): nn.Sqrt
      |      }
      |      (2): nn.SpatialConvolution(256 -> 64, 1x1)
      |      (3): nn.ReLU
      |    }
      |`-> (4): nn.Sequential {
             [input -> (1) -> (2) -> output]
             (1): nn.SpatialConvolution(256 -> 64, 1x1)
             (2): nn.ReLU
           }
       ... -> output
  }
  (7): nn.SpatialAveragePooling(7x7, 1,1)
  (8): nn.View(320)
  (9): nn.Linear(320 -> 128)
  (10): nn.Normalize(2)
}
2	nn.SpatialConvolutionMM(3 -> 64, 7x7, 2,2, 3,3)
3	nn.SpatialBatchNormalization
4	nn.ReLU
5	nn.SpatialMaxPooling(3x3, 2,2, 1,1)
6	nn.Inception @ nn.DepthConcat {
  input
    |`-> (1): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
    |      (1): nn.SpatialConvolution(192 -> 96, 1x1)
    |      (2): nn.SpatialBatchNormalization
    |      (3): nn.ReLU
    |      (4): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
    |      (5): nn.SpatialBatchNormalization
    |      (6): nn.ReLU
    |    }
    |`-> (2): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
    |      (1): nn.SpatialConvolution(192 -> 16, 1x1)
    |      (2): nn.SpatialBatchNormalization
    |      (3): nn.ReLU
    |      (4): nn.SpatialConvolution(16 -> 32, 5x5, 1,1, 2,2)
    |      (5): nn.SpatialBatchNormalization
    |      (6): nn.ReLU
    |    }
    |`-> (3): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> output]
    |      (1): nn.SpatialMaxPooling(3x3, 1,1, 1,1)
    |      (2): nn.SpatialConvolution(192 -> 32, 1x1)
    |      (3): nn.SpatialBatchNormalization
    |      (4): nn.ReLU
    |    }
    |`-> (4): nn.Sequential {
           [input -> (1) -> (2) -> (3) -> output]
           (1): nn.SpatialConvolution(192 -> 64, 1x1)
           (2): nn.SpatialBatchNormalization
           (3): nn.ReLU
         }
     ... -> output
}
7	nn.DepthConcat {
  input
    |`-> (1): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
    |      (1): nn.SpatialConvolution(192 -> 96, 1x1)
    |      (2): nn.SpatialBatchNormalization
    |      (3): nn.ReLU
    |      (4): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
    |      (5): nn.SpatialBatchNormalization
    |      (6): nn.ReLU
    |    }
    |`-> (2): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
    |      (1): nn.SpatialConvolution(192 -> 16, 1x1)
    |      (2): nn.SpatialBatchNormalization
    |      (3): nn.ReLU
    |      (4): nn.SpatialConvolution(16 -> 32, 5x5, 1,1, 2,2)
    |      (5): nn.SpatialBatchNormalization
    |      (6): nn.ReLU
    |    }
    |`-> (3): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> output]
    |      (1): nn.SpatialMaxPooling(3x3, 1,1, 1,1)
    |      (2): nn.SpatialConvolution(192 -> 32, 1x1)
    |      (3): nn.SpatialBatchNormalization
    |      (4): nn.ReLU
    |    }
    |`-> (4): nn.Sequential {
           [input -> (1) -> (2) -> (3) -> output]
           (1): nn.SpatialConvolution(192 -> 64, 1x1)
           (2): nn.SpatialBatchNormalization
           (3): nn.ReLU
         }
     ... -> output
}
8	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
  (1): nn.SpatialConvolution(192 -> 96, 1x1)
  (2): nn.SpatialBatchNormalization
  (3): nn.ReLU
  (4): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
  (5): nn.SpatialBatchNormalization
  (6): nn.ReLU
}
9	nn.SpatialConvolution(192 -> 96, 1x1)
10	nn.SpatialBatchNormalization
11	nn.ReLU
12	nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
13	nn.SpatialBatchNormalization
14	nn.ReLU
15	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
  (1): nn.SpatialConvolution(192 -> 16, 1x1)
  (2): nn.SpatialBatchNormalization
  (3): nn.ReLU
  (4): nn.SpatialConvolution(16 -> 32, 5x5, 1,1, 2,2)
  (5): nn.SpatialBatchNormalization
  (6): nn.ReLU
}
16	nn.SpatialConvolution(192 -> 16, 1x1)
17	nn.SpatialBatchNormalization
18	nn.ReLU
19	nn.SpatialConvolution(16 -> 32, 5x5, 1,1, 2,2)
20	nn.SpatialBatchNormalization
21	nn.ReLU
22	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> output]
  (1): nn.SpatialMaxPooling(3x3, 1,1, 1,1)
  (2): nn.SpatialConvolution(192 -> 32, 1x1)
  (3): nn.SpatialBatchNormalization
  (4): nn.ReLU
}
23	nn.SpatialMaxPooling(3x3, 1,1, 1,1)
24	nn.SpatialConvolution(192 -> 32, 1x1)
25	nn.SpatialBatchNormalization
26	nn.ReLU
27	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> output]
  (1): nn.SpatialConvolution(192 -> 64, 1x1)
  (2): nn.SpatialBatchNormalization
  (3): nn.ReLU
}
28	nn.SpatialConvolution(192 -> 64, 1x1)
29	nn.SpatialBatchNormalization
30	nn.ReLU
31	nn.Inception @ nn.DepthConcat {
  input
    |`-> (1): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> output]
    |      (1): nn.SpatialConvolution(256 -> 96, 1x1)
    |      (2): nn.ReLU
    |      (3): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
    |      (4): nn.ReLU
    |    }
    |`-> (2): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> output]
    |      (1): nn.SpatialConvolution(256 -> 32, 1x1)
    |      (2): nn.ReLU
    |      (3): nn.SpatialConvolution(32 -> 64, 5x5, 1,1, 2,2)
    |      (4): nn.ReLU
    |    }
    |`-> (3): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> output]
    |      (1): nn.Sequential {
    |        [input -> (1) -> (2) -> (3) -> (4) -> output]
    |        (1): nn.Square
    |        (2): nn.SpatialAveragePooling(3x3, 1,1)
    |        (3): nn.MulConstant
    |        (4): nn.Sqrt
    |      }
    |      (2): nn.SpatialConvolution(256 -> 64, 1x1)
    |      (3): nn.ReLU
    |    }
    |`-> (4): nn.Sequential {
           [input -> (1) -> (2) -> output]
           (1): nn.SpatialConvolution(256 -> 64, 1x1)
           (2): nn.ReLU
         }
     ... -> output
}
32	nn.DepthConcat {
  input
    |`-> (1): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> output]
    |      (1): nn.SpatialConvolution(256 -> 96, 1x1)
    |      (2): nn.ReLU
    |      (3): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
    |      (4): nn.ReLU
    |    }
    |`-> (2): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> output]
    |      (1): nn.SpatialConvolution(256 -> 32, 1x1)
    |      (2): nn.ReLU
    |      (3): nn.SpatialConvolution(32 -> 64, 5x5, 1,1, 2,2)
    |      (4): nn.ReLU
    |    }
    |`-> (3): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> output]
    |      (1): nn.Sequential {
    |        [input -> (1) -> (2) -> (3) -> (4) -> output]
    |        (1): nn.Square
    |        (2): nn.SpatialAveragePooling(3x3, 1,1)
    |        (3): nn.MulConstant
    |        (4): nn.Sqrt
    |      }
    |      (2): nn.SpatialConvolution(256 -> 64, 1x1)
    |      (3): nn.ReLU
    |    }
    |`-> (4): nn.Sequential {
           [input -> (1) -> (2) -> output]
           (1): nn.SpatialConvolution(256 -> 64, 1x1)
           (2): nn.ReLU
         }
     ... -> output
}
33	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> output]
  (1): nn.SpatialConvolution(256 -> 96, 1x1)
  (2): nn.ReLU
  (3): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
  (4): nn.ReLU
}
34	nn.SpatialConvolution(256 -> 96, 1x1)
35	nn.ReLU
36	nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
37	nn.ReLU
38	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> output]
  (1): nn.SpatialConvolution(256 -> 32, 1x1)
  (2): nn.ReLU
  (3): nn.SpatialConvolution(32 -> 64, 5x5, 1,1, 2,2)
  (4): nn.ReLU
}
39	nn.SpatialConvolution(256 -> 32, 1x1)
40	nn.ReLU
41	nn.SpatialConvolution(32 -> 64, 5x5, 1,1, 2,2)
42	nn.ReLU
43	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> output]
  (1): nn.Sequential {
    [input -> (1) -> (2) -> (3) -> (4) -> output]
    (1): nn.Square
    (2): nn.SpatialAveragePooling(3x3, 1,1)
    (3): nn.MulConstant
    (4): nn.Sqrt
  }
  (2): nn.SpatialConvolution(256 -> 64, 1x1)
  (3): nn.ReLU
}
44	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> output]
  (1): nn.Square
  (2): nn.SpatialAveragePooling(3x3, 1,1)
  (3): nn.MulConstant
  (4): nn.Sqrt
}
45	nn.Square
46	nn.SpatialAveragePooling(3x3, 1,1)
47	nn.MulConstant
48	nn.Sqrt
49	nn.SpatialConvolution(256 -> 64, 1x1)
50	nn.ReLU
51	nn.Sequential {
  [input -> (1) -> (2) -> output]
  (1): nn.SpatialConvolution(256 -> 64, 1x1)
  (2): nn.ReLU
}
52	nn.SpatialConvolution(256 -> 64, 1x1)
53	nn.ReLU
54	nn.SpatialAveragePooling(7x7, 1,1)
55	nn.View(320)
56	nn.Linear(320 -> 128)
57	nn.Normalize(2)
```
From the output above, the 1st call to the callback receives the entire model, i.e. the top-level module.
Calls 2-5 print:
2 nn.SpatialConvolutionMM(3 -> 64, 7x7, 2,2, 3,3)
3 nn.SpatialBatchNormalization
4 nn.ReLU
5 nn.SpatialMaxPooling(3x3, 2,2, 1,1)
These are the layers that come before the first Inception.
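Call 6 is nn.Inception @ nn.DepthConcat and call 7 is nn.DepthConcat; together they make up the first Inception layer. The Inception shows up twice because dpnn's nn.Inception is a decorator (nn.Decorator) whose only child is the nn.DepthConcat that holds the four branches, so apply visits the wrapper first and then the DepthConcat inside it.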
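The double visit can be reproduced with a tiny plain-table sketch: a wrapper node whose modules array holds exactly one child is passed to the callback first, then its child, then the child's branches. The table layout and names below are hypothetical and only mimic the wrapper/DepthConcat relationship seen in the printout; no Torch is involved.

```lua
-- Same preorder logic as Module:apply, written as a plain function
local function apply(node, callback)
   callback(node)
   if node.modules then
      for _, child in ipairs(node.modules) do
         apply(child, callback)
      end
   end
end

-- Hypothetical wrapper whose single child is the container holding the branches
local concat = { name = 'DepthConcat', modules = { { name = 'branch1' }, { name = 'branch2' } } }
local inception = { name = 'Inception', modules = { concat } }

local visited = {}
apply(inception, function(m) visited[#visited + 1] = m.name end)
print(table.concat(visited, ', '))
-- prints: Inception, DepthConcat, branch1, branch2
```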
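Call 8 is the Inception's first nn.Sequential branch, and calls 9-14 are its individual layers. These are leaf modules, so here the traversal reaches the bottom of the tree for the first time.
Call 15 is the Inception's second nn.Sequential, and calls 16-21 are its layers; this reaches the second set of leaf modules.
Call 22 is the Inception's third nn.Sequential, and calls 23-26 are its layers; this reaches the third set of leaf modules.
Call 27 is the Inception's fourth nn.Sequential, and calls 28-30 are its layers; this reaches the fourth set of leaf modules.
At this point the first Inception layer has been fully traversed, depth-first.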
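apply itself does not tell the callback how deep a module sits. The sketch below uses a hypothetical depth-aware variant, applyWithDepth (not part of nn), on a plain-table copy of the first Inception's shape (four branches with 6, 6, 4 and 3 layers) to make the nesting behind calls 6-30 visible.

```lua
-- Hypothetical depth-aware variant of apply (not part of nn): the callback
-- also receives the module's depth, with the starting module at depth 1.
local function applyWithDepth(node, callback, depth)
   depth = depth or 1
   callback(node, depth)
   if node.modules then
      for _, child in ipairs(node.modules) do
         applyWithDepth(child, callback, depth + 1)
      end
   end
end

-- A Sequential-like node holding n placeholder layers
local function seq(n)
   local layers = {}
   for i = 1, n do layers[i] = { name = 'layer' } end
   return { name = 'Sequential', modules = layers }
end

-- Shape of the first Inception: wrapper -> DepthConcat -> four branches
local inception = { name = 'Inception', modules = {
   { name = 'DepthConcat', modules = { seq(6), seq(6), seq(4), seq(3) } },
} }

-- Indent each name by two spaces per level below the Inception
local lines = {}
applyWithDepth(inception, function(m, depth)
   lines[#lines + 1] = string.rep('  ', depth - 1) .. m.name
end)
print(table.concat(lines, '\n'))
```

The sketch produces 25 entries, matching calls 6-30. Adding 5 (for calls 1-5) to the positions of the four Sequential entries (3, 10, 17 and 22) gives 8, 15, 22 and 27, the call numbers listed above.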
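Call 31 is nn.Inception @ nn.DepthConcat and call 32 is nn.DepthConcat; this is the second Inception layer. (To tell the two Inception layers apart, they were deliberately given slightly different structures.)
Call 33 is the Inception's first nn.Sequential, and calls 34-37 are its layers; this reaches the first set of leaf modules.
Call 38 is the Inception's second nn.Sequential, and calls 39-42 are its layers; this reaches the second set of leaf modules.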
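Call 43 is the Inception's third nn.Sequential.
Call 44 is the first submodule of that third nn.Sequential, and it is itself an nn.Sequential: it is the nn.SpatialLPPooling passed as pool, which nn implements as a Sequential of Square, SpatialAveragePooling, MulConstant and Sqrt. Calls 45-48 walk through this inner nn.Sequential down to its leaves, which completes it.
Calls 49-50 are the remaining two layers of the third nn.Sequential.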
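Call 51 is the Inception's fourth nn.Sequential, and calls 52-53 are its layers; this reaches the fourth set of leaf modules.
At this point the second Inception layer has been fully traversed, depth-first.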
Calls 54-57 are the last four layers.
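As a check on the numbering above, this sketch rebuilds only the shape of the test model (which module contains which) with plain Lua tables and counts visits the way the test callback does. Labels such as inception1, b1_2 and lppool are hypothetical names chosen for the sketch; no Torch is needed.

```lua
-- Same preorder logic as Module:apply
local function apply(node, callback)
   callback(node)
   if node.modules then
      for _, child in ipairs(node.modules) do apply(child, callback) end
   end
end

local function leaf(name) return { name = name } end
local function seq(name, children) return { name = name, modules = children } end

-- A Sequential-like branch made of n leaf layers named "<name>.<i>"
local function branch(name, n)
   local layers = {}
   for i = 1, n do layers[i] = leaf(name .. '.' .. i) end
   return seq(name, layers)
end

-- Shape of the test model; each Inception is a wrapper around one DepthConcat
local model = seq('model', {
   leaf('conv'), leaf('bn'), leaf('relu'), leaf('maxpool'),
   seq('inception1', { seq('concat1', {
      branch('b1_1', 6), branch('b1_2', 6), branch('b1_3', 4), branch('b1_4', 3),
   }) }),
   seq('inception2', { seq('concat2', {
      branch('b2_1', 4), branch('b2_2', 4),
      seq('b2_3', { branch('lppool', 4), leaf('b2_3.conv'), leaf('b2_3.relu') }),
      branch('b2_4', 2),
   }) }),
   leaf('avgpool'), leaf('view'), leaf('linear'), leaf('normalize'),
})

-- Record the call number at which each named node is visited
local tt, index = 0, {}
apply(model, function(m)
   tt = tt + 1
   index[m.name] = tt
end)

print(tt, index.inception1, index.b1_2, index.inception2, index.lppool, index.normalize)
-- 57 6 15 31 44 57 (tab-separated)
```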
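This confirms that apply traverses the module tree depth-first. More precisely, it is a preorder traversal: each module is passed to the callback before any of its children, as the comment in Module.lua states.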
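For contrast, here is apply's recursive preorder next to a hypothetical breadth-first traversal (applyBFS is not an nn function) on a small plain-table tree. The depth-first version finishes a whole branch before moving to its sibling, while the queue-based one visits the tree level by level.

```lua
-- Preorder depth-first, same logic as Module:apply
local function applyDFS(node, callback)
   callback(node)
   if node.modules then
      for _, child in ipairs(node.modules) do applyDFS(child, callback) end
   end
end

-- Hypothetical breadth-first traversal using a FIFO queue (not part of nn)
local function applyBFS(root, callback)
   local queue, head = { root }, 1
   while head <= #queue do
      local node = queue[head]
      head = head + 1
      callback(node)
      if node.modules then
         for _, child in ipairs(node.modules) do queue[#queue + 1] = child end
      end
   end
end

-- Tree: A has children B and C; B has children D and E
local tree = { name = 'A', modules = {
   { name = 'B', modules = { { name = 'D' }, { name = 'E' } } },
   { name = 'C' },
} }

local dfs, bfs = {}, {}
applyDFS(tree, function(m) dfs[#dfs + 1] = m.name end)
applyBFS(tree, function(m) bfs[#bfs + 1] = m.name end)
print(table.concat(dfs, ' '))  -- prints: A B D E C
print(table.concat(bfs, ' '))  -- prints: A B C D E
```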