首页 > 代码库 > 使用caffe提供的python接口训练mnist例子

使用caffe提供的python接口训练mnist例子

1 首先肯定是安装caffe,并且编译python接口,如果是在windows上,最好把编译出来的python文件夹的caffe文件夹拷贝到anaconda文件夹下面去,这样就有代码自动提示功能,如下:

技术分享

本文中使用的ide为anaconda安装中自带的spyder,如图所示,将根目录设置为caffe的根目录。

技术分享

import caffe
caffe.set_mode_cpu()
solver = caffe.SGDSolver(examples/mnist/lenet_solver.prototxt)
solver.solve()

以上为一次全部迭代,如果想自己控制,可使用如下代码:

import caffe
caffe.set_mode_cpu()
solver = caffe.SGDSolver(examples/mnist/lenet_solver.prototxt)
#solver.solve()

iter = solver.iter
while iter<10000:
    solver.step(1)
    iter = solver.iter
    input_data = solver.net.blobs[data].data  
    loss = solver.net.blobs[loss].data
    accuracy = solver.test_nets[0].blobs[accuracy].data
    print iter:, iter, loss:, loss,accuracy:,accuracy
import caffe
import matplotlib.pyplot as plt     
import numpy as np

def vis_square(data):
    """Take an array of shape (n, height, width) or (n, height, width, 3)
       and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)"""
    
    # normalize data for display
    data = http://www.mamicode.com/(data - data.min()) / (data.max() - data.min())
    
    # force the number of filters to be square
    n = int(np.ceil(np.sqrt(data.shape[0])))
    padding = (((0, n ** 2 - data.shape[0]),
               (0, 1), (0, 1))                 # add some space between filters
               + ((0, 0),) * (data.ndim - 3))  # don‘t pad the last dimension (if there is one)
    data = http://www.mamicode.com/np.pad(data, padding, mode=constant, constant_values=1)  # pad with ones (white)
    
    # tile the filters into an image
    data = http://www.mamicode.com/data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
    if data.shape[2] == 1:
        data = data[:,:,0]
    plt.imshow(data); plt.axis(off)

if __name__ == __main__:
    caffe.set_mode_cpu()
    solver = caffe.SGDSolver(examples/mnist/lenet_solver.prototxt)
    solver.step(1)
    input_data = solver.net.blobs[data].data  
    plt.figure(0)
    vis_square(input_data.transpose(0, 2, 3, 1))  
    filters = solver.net.params[conv1][0].data
    plt.figure(1)
    vis_square(filters.transpose(0, 2, 3, 1))

    特征图:
技术分享   
    权值图

技术分享

使用caffe提供的python接口训练mnist例子