首页 > 代码库 > 用java写bp神经网络(二)

用java写bp神经网络(二)

接上篇。

有了Net和Propagation之后,我们就可以开始训练了。训练师(Trainer)要负责的事情有三件:一是如何把一大批样本分成小批进行训练,再把各小批的结果合并成完整的结果(批量/增量);二是何时调用学习师(Learner)根据训练结果进行学习,从而改进网络的权重和状态;三是何时结束训练。

那么这两位老师儿长的什么样子,又是怎么做到的呢?

/**
 * Drives the training of a network: decides how the data is batched, when the
 * {@link Learner} is invoked and when training stops.
 */
public interface Trainer {
    /**
     * Trains {@code net} on the samples supplied by {@code provider}.
     *
     * @param net      network to train
     * @param provider source of the input and target matrices
     */
    void train(Net net, DataProvider provider);
}

/**
 * Applies a weight-update rule to a network given the result of a training pass.
 */
public interface Learner {
    /**
     * Adjusts the weights of {@code net} using the gradients held in {@code trainResult}.
     *
     * @param net         network whose weights are updated
     * @param trainResult result of the preceding training pass
     */
    void learn(Net net, TrainResult trainResult);
}

 所谓Trainer即是给定数据,对指定网络进行训练;所谓Learner即是给定训练结果,然后对指定网络进行权重调整。

下面给出这两个接口的简单实现。

Trainer

下面的Trainer实现(CommonTrainer)提供简单的批量训练功能,在达到给定的迭代次数(epoch)后停止。代码示例如下。

public class CommonTrainer implements Trainer {	int ecophs;	Learner learner;	List<Double> costs = new ArrayList<>();	List<Double> accuracys = new ArrayList<>();	int batchSize = 1;	public CommonTrainer(int ecophs, Learner learner) {		super();		this.ecophs = ecophs;		this.learner = learner == null ? new MomentAdaptLearner() : learner;	}	public CommonTrainer(int ecophs, Learner learner, int batchSize) {		this(ecophs, learner);		this.batchSize = batchSize;	}	public void trainOne(final Net net, DataProvider provider) {		final Propagation propagation = new Propagation(net);		DoubleMatrix input = provider.getInput();		DoubleMatrix target = provider.getTarget();		final int allLen = target.columns;		final int[] nodesNum = net.getNodesNum();		final int layersNum = net.getLayersNum();		List<DoubleMatrix> inputBatches = this.getBatches(input);		final List<DoubleMatrix> targetBatches = this.getBatches(target);		final List<Integer> batchLen = MatrixUtil.getEndPosition(targetBatches);		final BackwardResult backwardResult = new BackwardResult(net, allLen);          // 分批并行训练		Parallel.For(inputBatches, new Parallel.Operation<DoubleMatrix>() {			@Override			public void perform(int index, DoubleMatrix subInput) {				ForwardResult subResult = propagation.forward(subInput);				DoubleMatrix subTarget = targetBatches.get(index);				BackwardResult backResult = propagation.backward(subTarget,						subResult);				DoubleMatrix cost = backwardResult.cost;				DoubleMatrix accuracy = backwardResult.accuracy;				DoubleMatrix inputDeltas = backwardResult.getInputDelta();				int start = index == 0 ? 
0 : batchLen.get(index - 1);				int end = batchLen.get(index) - 1;				int[] cIndexs = ArraysHelper.makeArray(start, end);				cost.put(cIndexs, backResult.cost);				if (accuracy != null) {					accuracy.put(cIndexs, backResult.accuracy);				}				inputDeltas.put(ArraysHelper.makeArray(0, nodesNum[0] - 1),						  cIndexs, backResult.getInputDelta());				for (int i = 0; i < layersNum; i++) {					DoubleMatrix gradients = backwardResult.gradients.get(i);					DoubleMatrix biasGradients = backwardResult.biasGradients							.get(i);  					DoubleMatrix subGradients = backResult.gradients.get(i)							.muli(backResult.cost.columns);					DoubleMatrix subBiasGradients = backResult.biasGradients							.get(i).muli(backResult.cost.columns);					gradients.addi(subGradients);					biasGradients.addi(subBiasGradients);				}			}		});         // 求均值		for(DoubleMatrix gradient:backwardResult.gradients){			gradient.divi(allLen);		}		for(DoubleMatrix gradient:backwardResult.biasGradients){			gradient.divi(allLen);		}				// this.mergeBackwardResult(backResults, net, input.columns);		TrainResult trainResult = new TrainResult(null, backwardResult);		learner.learn(net, trainResult);		Double cost = backwardResult.getMeanCost();		Double accuracy = backwardResult.getMeanAccuracy();		if (cost != null)			costs.add(cost);		if (accuracy != null)			accuracys.add(accuracy);  		System.out.println(cost);		System.out.println(accuracy);	}	@Override	public void train(Net net, DataProvider provider) {		for (int i = 0; i < this.ecophs; i++) {			this.trainOne(net, provider);		}	}}

 

Learner

Learner是具体的调整算法,当梯度计算出来后,它负责对网络权重进行调整。调整算法的选择直接影响着网络收敛的快慢。本文的实现采用简单的动量-自适应学习率算法。

其迭代公式如下:

$$W(t+1)=W(t)-\Delta W(t)$$

$$\Delta W(t)=rate(t)(1-moment(t))G(t)+moment(t)\Delta W(t-1)$$

$$rate(t+1)=\begin{cases} rate(t)\times 1.05 & \mbox{if } cost(t)<cost(t-1)\\ rate(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 0.01 & \mbox{else} \end{cases}$$

$$moment(t+1)=\begin{cases} 0.9 & \mbox{if } cost(t)<cost(t-1)\\ moment(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 1-0.9 & \mbox{else} \end{cases}$$

示例代码如下:

public class MomentAdaptLearner implements Learner {	Net net;	double moment = 0.9;	double lmd = 1.05;	double preCost = 0;	double eta = 0.01;	double currentEta=eta;	double currentMoment=moment;	TrainResult preTrainResult;		public MomentAdaptLearner(double moment, double eta) {		super();		this.moment = moment;		this.eta = eta;		this.currentEta=eta;		this.currentMoment=moment;	}	@Override	public void learn(Net net, TrainResult trainResult) {		if (this.net == null)			init(net);				BackwardResult backwardResult = trainResult.backwardResult;		BackwardResult preBackwardResult = preTrainResult.backwardResult;		double cost=backwardResult.getMeanCost();		this.modifyParameter(cost);		System.out.println("current eta:"+this.currentEta);		System.out.println("current moment:"+this.currentMoment);		for (int j = 0; j < net.getLayersNum(); j++) {			DoubleMatrix weight = net.getWeights().get(j);			DoubleMatrix gradient = backwardResult.gradients.get(j);			gradient = gradient.muli(currentEta * (1 - this.currentMoment)).addi(					preBackwardResult.gradients.get(j).muli(this.currentMoment));			preBackwardResult.gradients.set(j, gradient);			weight.subi(gradient);						DoubleMatrix b = net.getBs().get(j);			DoubleMatrix bgradient = backwardResult.biasGradients.get(j);			bgradient = bgradient.muli(currentEta * (1 - this.currentMoment)).addi(					preBackwardResult.biasGradients.get(j).muli(this.currentMoment));			preBackwardResult.biasGradients.set(j, bgradient);			b.subi(bgradient);		}	}	public void modifyParameter(double cost){		if(cost<this.preCost){			this.currentEta*=1.05;			this.currentMoment=moment;		}else if(cost<1.04*this.preCost){			this.currentEta*=0.7;			this.currentMoment*=0.7;		}else{			this.currentEta=eta;			this.currentMoment=1-moment;		}		this.preCost=cost;	}	public void init(Net net) {		this.net =  net;		BackwardResult bResult = new BackwardResult();		for (DoubleMatrix weight : net.getWeights()) {			bResult.gradients.add(DoubleMatrix.zeros(weight.rows,					
weight.columns));		}		for (DoubleMatrix b : net.getBs()) {			bResult.biasGradients.add(DoubleMatrix.zeros(b.rows, b.columns));		}		preTrainResult=new TrainResult(null,bResult);	}}

现在,一个简单的神经网络从生成到训练已经全部实现完毕。

用java写bp神经网络(二)