首页 > 代码库 > mxnet实战系列(一)入门与跑mnist数据集
mxnet实战系列(一)入门与跑mnist数据集
最近在摸mxnet和tensorflow。两个我都搭起来了。tensorflow跑了不少代码,总的来说用得比较顺畅,文档很丰富,api熟悉熟悉写代码没什么问题。
今天把两个平台做了一下对比。同是跑mnist,tensorflow 要比mxnet 慢一二十倍。mxnet只需要半分钟,tensorflow跑了13分钟。
在mxnet中如何开跑?
cd /mxnet/example/image-classification python train_mnist.py
我用的是最新的mxnet版本。运行脚本它会自动下载数据集。
然后刷刷刷的刷屏了。
我们来看看这个脚本如何写的,从而建立mxnet编程思路:
import find_mxnet
import mxnet as mx
import argparse
import os, sys
import train_model
def _download(data_dir):
if not os.path.isdir(data_dir):
os.system("mkdir " + data_dir)
os.chdir(data_dir)
if (not os.path.exists(‘train-images-idx3-ubyte‘)) or \
(not os.path.exists(‘train-labels-idx1-ubyte‘)) or \
(not os.path.exists(‘t10k-images-idx3-ubyte‘)) or \
(not os.path.exists(‘t10k-labels-idx1-ubyte‘)):
os.system("wget http://data.dmlc.ml/mxnet/data/mnist.zip")
os.system("unzip -u mnist.zip; rm mnist.zip")
os.chdir("..")
def get_loc(data, attr={‘lr_mult‘:‘0.01‘}):
"""
the localisation network in lenet-stn, it will increase acc about more than 1%,
when num-epoch >=15
"""
loc = mx.symbol.Convolution(data=http://www.mamicode.com/data, num_filter=30, kernel=(5, 5), stride=(2,2))
loc = mx.symbol.Activation(data = http://www.mamicode.com/loc, act_type=‘relu‘)
loc = mx.symbol.Pooling(data=http://www.mamicode.com/loc, kernel=(2, 2), stride=(2, 2), pool_type=‘max‘)
loc = mx.symbol.Convolution(data=http://www.mamicode.com/loc, num_filter=60, kernel=(3, 3), stride=(1,1), pad=(1, 1))
loc = mx.symbol.Activation(data = http://www.mamicode.com/loc, act_type=‘relu‘)
loc = mx.symbol.Pooling(data=http://www.mamicode.com/loc, global_pool=True, kernel=(2, 2), pool_type=‘avg‘)
loc = mx.symbol.Flatten(data=http://www.mamicode.com/loc)
loc = mx.symbol.FullyConnected(data=http://www.mamicode.com/loc, num_hidden=6, name="stn_loc", attr=attr)
return loc
def get_mlp():
"""
multi-layer perceptron
"""
data = http://www.mamicode.com/mx.symbol.Variable(‘data‘)
fc1 = mx.symbol.FullyConnected(data = http://www.mamicode.com/data, name=‘fc1‘, num_hidden=128)
act1 = mx.symbol.Activation(data = http://www.mamicode.com/fc1, name=‘relu1‘, act_type="relu")
fc2 = mx.symbol.FullyConnected(data = http://www.mamicode.com/act1, name = ‘fc2‘, num_hidden = 64)
act2 = mx.symbol.Activation(data = http://www.mamicode.com/fc2, name=‘relu2‘, act_type="relu")
fc3 = mx.symbol.FullyConnected(data = http://www.mamicode.com/act2, name=‘fc3‘, num_hidden=10)
mlp = mx.symbol.SoftmaxOutput(data = http://www.mamicode.com/fc3, name = ‘softmax‘)
return mlp
def get_lenet(add_stn=False):
"""
LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick
Haffner. "Gradient-based learning applied to document recognition."
Proceedings of the IEEE (1998)
"""
data = http://www.mamicode.com/mx.symbol.Variable(‘data‘)
if(add_stn):
data = http://www.mamicode.com/mx.sym.SpatialTransformer(data=data, loc=get_loc(data), target_shape = (28,28),
transform_type="affine", sampler_type="bilinear")
# first conv
conv1 = mx.symbol.Convolution(data=http://www.mamicode.com/data, kernel=(5,5), num_filter=20)
tanh1 = mx.symbol.Activation(data=http://www.mamicode.com/conv1, act_type="tanh")
pool1 = mx.symbol.Pooling(data=http://www.mamicode.com/tanh1, pool_type="max",
kernel=(2,2), stride=(2,2))
# second conv
conv2 = mx.symbol.Convolution(data=http://www.mamicode.com/pool1, kernel=(5,5), num_filter=50)
tanh2 = mx.symbol.Activation(data=http://www.mamicode.com/conv2, act_type="tanh")
pool2 = mx.symbol.Pooling(data=http://www.mamicode.com/tanh2, pool_type="max",
kernel=(2,2), stride=(2,2))
# first fullc
flatten = mx.symbol.Flatten(data=http://www.mamicode.com/pool2)
fc1 = mx.symbol.FullyConnected(data=http://www.mamicode.com/flatten, num_hidden=500)
tanh3 = mx.symbol.Activation(data=http://www.mamicode.com/fc1, act_type="tanh")
# second fullc
fc2 = mx.symbol.FullyConnected(data=http://www.mamicode.com/tanh3, num_hidden=10)
# loss
lenet = mx.symbol.SoftmaxOutput(data=http://www.mamicode.com/fc2, name=‘softmax‘)
return lenet
def get_iterator(data_shape):
def get_iterator_impl(args, kv):
data_dir = args.data_dir
if ‘://‘ not in args.data_dir:
_download(args.data_dir)
flat = False if len(data_shape) == 3 else True
train = mx.io.MNISTIter(
image = data_dir + "train-images-idx3-ubyte",
label = data_dir + "train-labels-idx1-ubyte",
input_shape = data_shape,
batch_size = args.batch_size,
shuffle = True,
flat = flat,
num_parts = kv.num_workers,
part_index = kv.rank)
val = mx.io.MNISTIter(
image = data_dir + "t10k-images-idx3-ubyte",
label = data_dir + "t10k-labels-idx1-ubyte",
input_shape = data_shape,
batch_size = args.batch_size,
flat = flat,
num_parts = kv.num_workers,
part_index = kv.rank)
return (train, val)
return get_iterator_impl
def parse_args():
parser = argparse.ArgumentParser(description=‘train an image classifer on mnist‘)
parser.add_argument(‘--network‘, type=str, default=‘mlp‘,
choices = [‘mlp‘, ‘lenet‘, ‘lenet-stn‘],
help = ‘the cnn to use‘)
parser.add_argument(‘--data-dir‘, type=str, default=‘mnist/‘,
help=‘the input data directory‘)
parser.add_argument(‘--gpus‘, type=str,
help=‘the gpus will be used, e.g "0,1,2,3"‘)
parser.add_argument(‘--num-examples‘, type=int, default=60000,
help=‘the number of training examples‘)
parser.add_argument(‘--batch-size‘, type=int, default=128,
help=‘the batch size‘)
parser.add_argument(‘--lr‘, type=float, default=.1,
help=‘the initial learning rate‘)
parser.add_argument(‘--model-prefix‘, type=str,
help=‘the prefix of the model to load/save‘)
parser.add_argument(‘--save-model-prefix‘, type=str,
help=‘the prefix of the model to save‘)
parser.add_argument(‘--num-epochs‘, type=int, default=10,
help=‘the number of training epochs‘)
parser.add_argument(‘--load-epoch‘, type=int,
help="load the model on an epoch using the model-prefix")
parser.add_argument(‘--kv-store‘, type=str, default=‘local‘,
help=‘the kvstore type‘)
parser.add_argument(‘--lr-factor‘, type=float, default=1,
help=‘times the lr with a factor for every lr-factor-epoch epoch‘)
parser.add_argument(‘--lr-factor-epoch‘, type=float, default=1,
help=‘the number of epoch to factor the lr, could be .5‘)
return parser.parse_args()
if __name__ == ‘__main__‘:
args = parse_args()
if args.network == ‘mlp‘:
data_shape = (784, )
net = get_mlp()
elif args.network == ‘lenet-stn‘:
data_shape = (1, 28, 28)
net = get_lenet(True)
else:
data_shape = (1, 28, 28)
net = get_lenet()
# train
train_model.fit(args, net, get_iterator(data_shape))
先看Main函数,就是读配置参数,读网络结构,包括设置数据的大小,然后就是调用已有的包train_model。然后传入这之前设置的三个参数。就开始训练了。
编程架构也蛮清晰的。模块化也搞的好。
接着看看参数设置问题。参数导入了很多配置文件,基本上caffe中的Proto都在这个里面设置了。包括数据集地址,批大小,学习率,损失函数,等等。然后看看读网络结构,
读网络结构就是在一层一层的搭积木,根据之前读入的配置文件或者自己定义一些参数。搭好积木就开始训练了。
caffe的一个缺点是不够灵活,毕竟不是自己写代码,只是写配置文件,总感觉受制于人。mxnet和tensorflow就比较方便,提供api,你可以按你的方式来调用和定义
网络结构。总的说来,其实是后两个框架模块化做的好,提供底层的api支持你写自己的网络。caffe要自己写网络层的话还是很费劲的
mxnet实战系列(一)入门与跑mnist数据集
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。