
(Original) The apply function in torch

Please credit the source when reposting:

http://www.cnblogs.com/darkknightzh/p/6221633.html

 

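The apply function in torch visits every module of a model in turn. It is in fact a depth-first traversal: each module is visited before its children (preorder).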

Its code is shown below (see torch/install/share/lua/5.1/nn/Module.lua):

-- Run a callback (called with the module as an argument) in preorder over this
-- module and its children.
--
function Module:apply(callback)
    callback(self)

    if self.modules then
        for _, module in ipairs(self.modules) do
            module:apply(callback)
        end
    end
end

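As the code shows, apply first calls the callback on the current module and then calls itself recursively on each child in self.modules. The recursion stops at modules that have no self.modules field, which are typically plain (non-container) layers.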
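The visiting order can be checked without torch. The following is a minimal sketch in plain Lua: Node and Node.new are hypothetical stand-ins for nn.Module, and only the body of apply is copied from the code above.

```lua
-- Minimal stand-in for nn.Module built from plain Lua tables (no torch needed).
-- Node and Node.new are hypothetical names used only for this illustration.
local Node = {}
Node.__index = Node

function Node.new(name, children)
    return setmetatable({ name = name, modules = children }, Node)
end

-- Same logic as Module:apply: visit self first, then each child in order.
function Node:apply(callback)
    callback(self)
    if self.modules then
        for _, module in ipairs(self.modules) do
            module:apply(callback)
        end
    end
end

local tree = Node.new("root", {
    Node.new("a"),
    Node.new("b", { Node.new("b1"), Node.new("b2") }),
    Node.new("c"),
})

-- Record the names in the order the callback receives them.
local order = {}
tree:apply(function(m) table.insert(order, m.name) end)
print(table.concat(order, " "))  -- root a b b1 b2 c
```

The children of b are visited before c, which is what distinguishes a depth-first traversal from a level-by-level one.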

 

The following test code exercises it:

require "dpnn"

function createModel()
   local net = nn.Sequential()

   net:add(nn.SpatialConvolutionMM(3, 64, 7, 7, 2, 2, 3, 3))
   net:add(nn.SpatialBatchNormalization(64))
   net:add(nn.ReLU())
   net:add(nn.SpatialMaxPooling(3, 3, 2, 2, 1, 1))

   net:add(nn.Inception{
     inputSize = 192,
     kernelSize = {3, 5},
     kernelStride = {1, 1},
     outputSize = {128, 32},
     reduceSize = {96, 16, 32, 64},
     pool = nn.SpatialMaxPooling(3, 3, 1, 1, 1, 1),
     batchNorm = true
   })

   net:add(nn.Inception{
     inputSize = 256,
     kernelSize = {3, 5},
     kernelStride = {1, 1},
     outputSize = {128, 64},
     reduceSize = {96, 32, 64, 64},
     pool = nn.SpatialLPPooling(256, 2, 3, 3, 1, 1),
     batchNorm = false
   })

   net:add(nn.SpatialAveragePooling(7, 7))
   net:add(nn.View(320))
   net:add(nn.Linear(320, 128))
   net:add(nn.Normalize(2))

   return net
end


torch.setdefaulttensortype(torch.FloatTensor)

local model = createModel()

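-- print(model)
-- Count the callback invocations so that each printed module is numbered
-- in the order apply visits it.
local tt = 0
model:apply(function(module)
    tt = tt + 1
    print(tt, module)
end)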

Its output is:

1	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> output]
  (1): nn.SpatialConvolutionMM(3 -> 64, 7x7, 2,2, 3,3)
  (2): nn.SpatialBatchNormalization
  (3): nn.ReLU
  (4): nn.SpatialMaxPooling(3x3, 2,2, 1,1)
  (5): nn.Inception @ nn.DepthConcat {
    input
      |`-> (1): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
      |      (1): nn.SpatialConvolution(192 -> 96, 1x1)
      |      (2): nn.SpatialBatchNormalization
      |      (3): nn.ReLU
      |      (4): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
      |      (5): nn.SpatialBatchNormalization
      |      (6): nn.ReLU
      |    }
      |`-> (2): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
      |      (1): nn.SpatialConvolution(192 -> 16, 1x1)
      |      (2): nn.SpatialBatchNormalization
      |      (3): nn.ReLU
      |      (4): nn.SpatialConvolution(16 -> 32, 5x5, 1,1, 2,2)
      |      (5): nn.SpatialBatchNormalization
      |      (6): nn.ReLU
      |    }
      |`-> (3): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> (4) -> output]
      |      (1): nn.SpatialMaxPooling(3x3, 1,1, 1,1)
      |      (2): nn.SpatialConvolution(192 -> 32, 1x1)
      |      (3): nn.SpatialBatchNormalization
      |      (4): nn.ReLU
      |    }
      |`-> (4): nn.Sequential {
             [input -> (1) -> (2) -> (3) -> output]
             (1): nn.SpatialConvolution(192 -> 64, 1x1)
             (2): nn.SpatialBatchNormalization
             (3): nn.ReLU
           }
       ... -> output
  }
  (6): nn.Inception @ nn.DepthConcat {
    input
      |`-> (1): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> (4) -> output]
      |      (1): nn.SpatialConvolution(256 -> 96, 1x1)
      |      (2): nn.ReLU
      |      (3): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
      |      (4): nn.ReLU
      |    }
      |`-> (2): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> (4) -> output]
      |      (1): nn.SpatialConvolution(256 -> 32, 1x1)
      |      (2): nn.ReLU
      |      (3): nn.SpatialConvolution(32 -> 64, 5x5, 1,1, 2,2)
      |      (4): nn.ReLU
      |    }
      |`-> (3): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> output]
      |      (1): nn.Sequential {
      |        [input -> (1) -> (2) -> (3) -> (4) -> output]
      |        (1): nn.Square
      |        (2): nn.SpatialAveragePooling(3x3, 1,1)
      |        (3): nn.MulConstant
      |        (4): nn.Sqrt
      |      }
      |      (2): nn.SpatialConvolution(256 -> 64, 1x1)
      |      (3): nn.ReLU
      |    }
      |`-> (4): nn.Sequential {
             [input -> (1) -> (2) -> output]
             (1): nn.SpatialConvolution(256 -> 64, 1x1)
             (2): nn.ReLU
           }
       ... -> output
  }
  (7): nn.SpatialAveragePooling(7x7, 1,1)
  (8): nn.View(320)
  (9): nn.Linear(320 -> 128)
  (10): nn.Normalize(2)
}
2	nn.SpatialConvolutionMM(3 -> 64, 7x7, 2,2, 3,3)
3	nn.SpatialBatchNormalization
4	nn.ReLU
5	nn.SpatialMaxPooling(3x3, 2,2, 1,1)
6	nn.Inception @ nn.DepthConcat {
  input
    |`-> (1): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
    |      (1): nn.SpatialConvolution(192 -> 96, 1x1)
    |      (2): nn.SpatialBatchNormalization
    |      (3): nn.ReLU
    |      (4): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
    |      (5): nn.SpatialBatchNormalization
    |      (6): nn.ReLU
    |    }
    |`-> (2): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
    |      (1): nn.SpatialConvolution(192 -> 16, 1x1)
    |      (2): nn.SpatialBatchNormalization
    |      (3): nn.ReLU
    |      (4): nn.SpatialConvolution(16 -> 32, 5x5, 1,1, 2,2)
    |      (5): nn.SpatialBatchNormalization
    |      (6): nn.ReLU
    |    }
    |`-> (3): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> output]
    |      (1): nn.SpatialMaxPooling(3x3, 1,1, 1,1)
    |      (2): nn.SpatialConvolution(192 -> 32, 1x1)
    |      (3): nn.SpatialBatchNormalization
    |      (4): nn.ReLU
    |    }
    |`-> (4): nn.Sequential {
           [input -> (1) -> (2) -> (3) -> output]
           (1): nn.SpatialConvolution(192 -> 64, 1x1)
           (2): nn.SpatialBatchNormalization
           (3): nn.ReLU
         }
     ... -> output
}
7	nn.DepthConcat {
  input
    |`-> (1): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
    |      (1): nn.SpatialConvolution(192 -> 96, 1x1)
    |      (2): nn.SpatialBatchNormalization
    |      (3): nn.ReLU
    |      (4): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
    |      (5): nn.SpatialBatchNormalization
    |      (6): nn.ReLU
    |    }
    |`-> (2): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
    |      (1): nn.SpatialConvolution(192 -> 16, 1x1)
    |      (2): nn.SpatialBatchNormalization
    |      (3): nn.ReLU
    |      (4): nn.SpatialConvolution(16 -> 32, 5x5, 1,1, 2,2)
    |      (5): nn.SpatialBatchNormalization
    |      (6): nn.ReLU
    |    }
    |`-> (3): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> output]
    |      (1): nn.SpatialMaxPooling(3x3, 1,1, 1,1)
    |      (2): nn.SpatialConvolution(192 -> 32, 1x1)
    |      (3): nn.SpatialBatchNormalization
    |      (4): nn.ReLU
    |    }
    |`-> (4): nn.Sequential {
           [input -> (1) -> (2) -> (3) -> output]
           (1): nn.SpatialConvolution(192 -> 64, 1x1)
           (2): nn.SpatialBatchNormalization
           (3): nn.ReLU
         }
     ... -> output
}
8	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
  (1): nn.SpatialConvolution(192 -> 96, 1x1)
  (2): nn.SpatialBatchNormalization
  (3): nn.ReLU
  (4): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
  (5): nn.SpatialBatchNormalization
  (6): nn.ReLU
}
9	nn.SpatialConvolution(192 -> 96, 1x1)
10	nn.SpatialBatchNormalization
11	nn.ReLU
12	nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
13	nn.SpatialBatchNormalization
14	nn.ReLU
15	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
  (1): nn.SpatialConvolution(192 -> 16, 1x1)
  (2): nn.SpatialBatchNormalization
  (3): nn.ReLU
  (4): nn.SpatialConvolution(16 -> 32, 5x5, 1,1, 2,2)
  (5): nn.SpatialBatchNormalization
  (6): nn.ReLU
}
16	nn.SpatialConvolution(192 -> 16, 1x1)
17	nn.SpatialBatchNormalization
18	nn.ReLU
19	nn.SpatialConvolution(16 -> 32, 5x5, 1,1, 2,2)
20	nn.SpatialBatchNormalization
21	nn.ReLU
22	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> output]
  (1): nn.SpatialMaxPooling(3x3, 1,1, 1,1)
  (2): nn.SpatialConvolution(192 -> 32, 1x1)
  (3): nn.SpatialBatchNormalization
  (4): nn.ReLU
}
23	nn.SpatialMaxPooling(3x3, 1,1, 1,1)
24	nn.SpatialConvolution(192 -> 32, 1x1)
25	nn.SpatialBatchNormalization
26	nn.ReLU
27	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> output]
  (1): nn.SpatialConvolution(192 -> 64, 1x1)
  (2): nn.SpatialBatchNormalization
  (3): nn.ReLU
}
28	nn.SpatialConvolution(192 -> 64, 1x1)
29	nn.SpatialBatchNormalization
30	nn.ReLU
31	nn.Inception @ nn.DepthConcat {
  input
    |`-> (1): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> output]
    |      (1): nn.SpatialConvolution(256 -> 96, 1x1)
    |      (2): nn.ReLU
    |      (3): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
    |      (4): nn.ReLU
    |    }
    |`-> (2): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> output]
    |      (1): nn.SpatialConvolution(256 -> 32, 1x1)
    |      (2): nn.ReLU
    |      (3): nn.SpatialConvolution(32 -> 64, 5x5, 1,1, 2,2)
    |      (4): nn.ReLU
    |    }
    |`-> (3): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> output]
    |      (1): nn.Sequential {
    |        [input -> (1) -> (2) -> (3) -> (4) -> output]
    |        (1): nn.Square
    |        (2): nn.SpatialAveragePooling(3x3, 1,1)
    |        (3): nn.MulConstant
    |        (4): nn.Sqrt
    |      }
    |      (2): nn.SpatialConvolution(256 -> 64, 1x1)
    |      (3): nn.ReLU
    |    }
    |`-> (4): nn.Sequential {
           [input -> (1) -> (2) -> output]
           (1): nn.SpatialConvolution(256 -> 64, 1x1)
           (2): nn.ReLU
         }
     ... -> output
}
32	nn.DepthConcat {
  input
    |`-> (1): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> output]
    |      (1): nn.SpatialConvolution(256 -> 96, 1x1)
    |      (2): nn.ReLU
    |      (3): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
    |      (4): nn.ReLU
    |    }
    |`-> (2): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> (4) -> output]
    |      (1): nn.SpatialConvolution(256 -> 32, 1x1)
    |      (2): nn.ReLU
    |      (3): nn.SpatialConvolution(32 -> 64, 5x5, 1,1, 2,2)
    |      (4): nn.ReLU
    |    }
    |`-> (3): nn.Sequential {
    |      [input -> (1) -> (2) -> (3) -> output]
    |      (1): nn.Sequential {
    |        [input -> (1) -> (2) -> (3) -> (4) -> output]
    |        (1): nn.Square
    |        (2): nn.SpatialAveragePooling(3x3, 1,1)
    |        (3): nn.MulConstant
    |        (4): nn.Sqrt
    |      }
    |      (2): nn.SpatialConvolution(256 -> 64, 1x1)
    |      (3): nn.ReLU
    |    }
    |`-> (4): nn.Sequential {
           [input -> (1) -> (2) -> output]
           (1): nn.SpatialConvolution(256 -> 64, 1x1)
           (2): nn.ReLU
         }
     ... -> output
}
33	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> output]
  (1): nn.SpatialConvolution(256 -> 96, 1x1)
  (2): nn.ReLU
  (3): nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
  (4): nn.ReLU
}
34	nn.SpatialConvolution(256 -> 96, 1x1)
35	nn.ReLU
36	nn.SpatialConvolution(96 -> 128, 3x3, 1,1, 1,1)
37	nn.ReLU
38	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> output]
  (1): nn.SpatialConvolution(256 -> 32, 1x1)
  (2): nn.ReLU
  (3): nn.SpatialConvolution(32 -> 64, 5x5, 1,1, 2,2)
  (4): nn.ReLU
}
39	nn.SpatialConvolution(256 -> 32, 1x1)
40	nn.ReLU
41	nn.SpatialConvolution(32 -> 64, 5x5, 1,1, 2,2)
42	nn.ReLU
43	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> output]
  (1): nn.Sequential {
    [input -> (1) -> (2) -> (3) -> (4) -> output]
    (1): nn.Square
    (2): nn.SpatialAveragePooling(3x3, 1,1)
    (3): nn.MulConstant
    (4): nn.Sqrt
  }
  (2): nn.SpatialConvolution(256 -> 64, 1x1)
  (3): nn.ReLU
}
44	nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> output]
  (1): nn.Square
  (2): nn.SpatialAveragePooling(3x3, 1,1)
  (3): nn.MulConstant
  (4): nn.Sqrt
}
45	nn.Square
46	nn.SpatialAveragePooling(3x3, 1,1)
47	nn.MulConstant
48	nn.Sqrt
49	nn.SpatialConvolution(256 -> 64, 1x1)
50	nn.ReLU
51	nn.Sequential {
  [input -> (1) -> (2) -> output]
  (1): nn.SpatialConvolution(256 -> 64, 1x1)
  (2): nn.ReLU
}
52	nn.SpatialConvolution(256 -> 64, 1x1)
53	nn.ReLU
54	nn.SpatialAveragePooling(7x7, 1,1)
55	nn.View(320)
56	nn.Linear(320 -> 128)
57	nn.Normalize(2)

The output shows that the first call of the callback receives the whole model, i.e. the top-level nn.Sequential.

 

Calls 2-5 print:

2       nn.SpatialConvolutionMM(3 -> 64, 7x7, 2,2, 3,3)

3       nn.SpatialBatchNormalization

4       nn.ReLU

5       nn.SpatialMaxPooling(3x3, 2,2, 1,1)

These are the layers that come before the first Inception module.

 

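Call 6 is nn.Inception @ nn.DepthConcat and call 7 is nn.DepthConcat. This is the first Inception module.

Call 8 is the first nn.Sequential inside the Inception module, and calls 9-14 are its individual layers. At this point the traversal has reached the bottom of the first branch.

Call 15 is the second nn.Sequential inside the Inception module, and calls 16-21 are its individual layers. This is the bottom of the second branch.

Call 22 is the third nn.Sequential inside the Inception module, and calls 23-26 are its individual layers. This is the bottom of the third branch.

Call 27 is the fourth nn.Sequential inside the Inception module, and calls 28-30 are its individual layers. This is the bottom of the fourth branch.

The first Inception module has now been traversed completely, in depth-first order.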
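The descent described above is easier to follow when the nesting depth is visible. Module:apply does not pass a depth to the callback, so the sketch below uses a hypothetical helper, applyWithDepth, on a plain-table stand-in for a small network. The helper and the table layout are assumptions made for illustration and are not part of nn.

```lua
-- applyWithDepth is a hypothetical helper, not part of nn.Module.
-- It works on any table that keeps its children in a `modules` array.
local function applyWithDepth(module, callback, depth)
    depth = depth or 0
    callback(module, depth)                      -- visit the parent first
    for _, child in ipairs(module.modules or {}) do
        applyWithDepth(child, callback, depth + 1)
    end
end

-- Plain-table stand-in for a small network (names are illustrative only).
local net = {
    name = "Sequential",
    modules = {
        { name = "Conv" },
        { name = "Inception", modules = {
            { name = "DepthConcat", modules = {
                { name = "Sequential", modules = {
                    { name = "Conv" },
                    { name = "ReLU" },
                } },
            } },
        } },
        { name = "Linear" },
    },
}

-- Indent each name by two spaces per nesting level.
local lines = {}
applyWithDepth(net, function(m, depth)
    table.insert(lines, string.rep("  ", depth) .. m.name)
end)
print(table.concat(lines, "\n"))
```

In the printed result, Linear appears only after everything nested under Inception, matching the way calls 8-30 above all come before the traversal moves on to the next top-level module.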

 

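Call 31 is nn.Inception @ nn.DepthConcat and call 32 is nn.DepthConcat. This is the second Inception module (note that the two Inception modules were deliberately given slightly different structures so that they can be told apart).

Call 33 is the first nn.Sequential inside the Inception module, and calls 34-37 are its individual layers. This is the bottom of the first branch.

Call 38 is the second nn.Sequential inside the Inception module, and calls 39-42 are its individual layers. This is the bottom of the second branch.

Call 43 is the third nn.Sequential inside the Inception module.

Call 44 is the first submodule of that third nn.Sequential, which is itself an nn.Sequential (it corresponds to the nn.SpatialLPPooling passed as pool). Calls 45-48 visit its layers in order, which completes this innermost level.

Calls 49-50 are the last two layers of the third nn.Sequential.

Call 51 is the fourth nn.Sequential inside the Inception module, and calls 52-53 are its individual layers. This is the bottom of the fourth branch.

The second Inception module has now been traversed completely, in depth-first order.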

 

Calls 54-57 are the last four layers.
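As a cross-check of the numbering, the nesting of the test model can be rebuilt from plain tables and the visits counted. This is a rough sketch, not the real model: leaf, container, leaves, seq, inception and count are hypothetical helpers, and the structure (an Inception holding one DepthConcat, which holds the branch Sequentials) is read off the printed output above.

```lua
-- Plain-table skeleton of the test model (no torch needed). All helpers here
-- are hypothetical; only the nesting matters, not the layer parameters.
local function leaf() return {} end
local function container(children) return { modules = children } end
local function leaves(n)
    local t = {}
    for i = 1, n do t[i] = leaf() end
    return t
end
local function seq(n) return container(leaves(n)) end

-- Count visits the way Module:apply makes them: the node itself, then
-- everything below each child.
local function count(node)
    local n = 1
    for _, child in ipairs(node.modules or {}) do
        n = n + count(child)
    end
    return n
end

-- As in the output: nn.Inception wraps one nn.DepthConcat, which holds
-- the branch nn.Sequential containers.
local function inception(branches)
    return container({ container(branches) })
end

local inception1 = inception({ seq(6), seq(6), seq(4), seq(3) })
-- Third branch of the second Inception: a nested Sequential plus two layers.
local lpBranch = container({ seq(4), leaf(), leaf() })
local inception2 = inception({ seq(4), seq(4), lpBranch, seq(2) })

local model = container({
    leaf(), leaf(), leaf(), leaf(),   -- conv, batch norm, ReLU, max pooling
    inception1, inception2,
    leaf(), leaf(), leaf(), leaf(),   -- average pooling, View, Linear, Normalize
})

print(count(inception1), count(inception2), count(model))  -- 25, 23 and 57
```

The three numbers agree with the walkthrough: calls 6-30 cover the first Inception module (25 visits), calls 31-53 the second (23 visits), and the whole traversal makes 57 calls.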

 

As the walkthrough above shows, apply traverses the model depth-first.
