首页 > 代码库 > [DL] CNN源码分析
[DL] CNN源码分析
在Hinton的教程中, 使用Python的theano库搭建的CNN是其中重要一环, 而其中的所谓的SGD - stochastic gradient descend算法又是如何实现的呢? 看下面源码(篇幅考虑只取测试模型函数, 训练函数只是多了一个updates参数):
3 classifier = LogisticRegression(input=x, n_in=24 * 48, n_out=32) 7 cost = classifier.negative_log_likelihood(y)11 test_model = theano.function(inputs=[index],12 outputs=classifier.errors(y),13 givens={14 x: test_set_x[index * batch_size: (index + 1) * batch_size],15 y: test_set_y[index * batch_size: (index + 1) * batch_size]})
行3声明了一个对象classifer, 它的输入是符号x, 大小为24*48, 输出长度为32.
行11定义了一个theano的函数对象, 接收的是下标index, 使用输入数据的第index*batch_size~第(index+1)*batch_size个数据作为函数的输入, 输出为误差.
我们再来看看行12中的errors函数的定义:
def errors(self, y): # check if y has same dimension of y_pred if y.ndim != self.y_pred.ndim: raise TypeError(‘y should have the same shape as self.y_pred‘, (‘y‘, target.type, ‘y_pred‘, self.y_pred.type)) # check if y is of the correct datatype if y.dtype.startswith(‘int‘): # the T.neq operator returns a vector of 0s and 1s, where 1 # represents a mistake in prediction return T.mean(T.neq(self.y_pred, y)) else: raise NotImplementedError()
self.y_pred 是一个大小为batch_size的向量, 每个元素代表batch_size中对应输入的网络判断结果, errors函数接受1个同等大小的期望输出y, 将两者进行比较求差后作均值返回, 这正是误差的定义.
那么问题来了, 这个 self.y_pred 是如何计算的? 这里我们看LogisticRegression的构造函数:
1 def __init__(self, input, n_in, n_out): 2 3 # initialize with 0 the weights W as a matrix of shape (n_in, n_out) 4 self.W = theano.shared(value=http://www.mamicode.com/numpy.zeros((n_in, n_out), 5 dtype=theano.config.floatX), 6 name=‘W‘, borrow=True) 7 # initialize the baises b as a vector of n_out 0s 8 self.b = theano.shared(value=http://www.mamicode.com/numpy.zeros((n_out,), 9 dtype=theano.config.floatX),10 name=‘b‘, borrow=True)11 12 # compute vector of class-membership probabilities in symbolic form13 self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b)14 15 # compute prediction as class whose probability is maximal in16 # symbolic form17 self.y_pred = T.argmax(self.p_y_given_x, axis=1)18 19 # parameters of the model20 self.params = [self.W, self.b]
[DL] CNN源码分析
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。