首页 > 代码库 > RBM

RBM

获取数据，放到 List 中

将数据集划分为训练集、验证集、测试集

 

新建 RBM 对象，确定可见层、隐含层的大小

训练RBM

 

新建线程集

 

public static void train(SGDBase sgd, List<SampleVector> samples, SGDTrainConfig config) {
    int xy_n = (int) samples.size();
    int nrModelReplica = config.getNbrModelReplica();

//划分数据集
    HashMap<Integer, List<SampleVector>> list_map = new HashMap<Integer, List<SampleVector>>();
    for (int i = 0; i < nrModelReplica; i++) {
        list_map.put(i, new ArrayList<SampleVector>());
    }
    Random rand = new Random(System.currentTimeMillis());
    for (SampleVector v: samples) {
        int id = rand.nextInt(nrModelReplica);
        list_map.get(id).add(v);
    }
    //新建线程,并且给线程赋数据
    List<DeltaThread> threads = new ArrayList<DeltaThread>();
    List<LossThread> loss_threads = new ArrayList<LossThread>();
    for (int i = 0; i < nrModelReplica; i++) {
        threads.add(new DeltaThread(sgd, config, list_map.get(i)));
        loss_threads.add(new LossThread(sgd));
    }

    // start iteration
    for (int epoch = 1; epoch <= config.getMaxEpochs(); epoch++) {
        // thread start
        for(DeltaThread thread : threads) {
            thread.train(epoch);
        }

        // waiting for all stop
        while (true) {
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                break;
            }
            boolean allStop = true;
            for(DeltaThread thread : threads) {
                if (thread.isRunning()) {
                    allStop = false;
                    break;
                }
            }
            if (allStop) {
                break;
            }
        }

        // update
        for(DeltaThread thread : threads) {
            sgd.mergeParam(thread.getParam(), nrModelReplica);
        }

        logger.info("train done for this iteration-" + epoch);

        /**
         * 1 parameter output
         */
        if(config.isParamOutput() && (0 == (epoch % config.getParamOutputStep()))) {
            SGDPersistableWrite.output(config.getParamOutputPath(), sgd);
        }
       
        /**
         * 2 loss print
         */
        if(!config.isPrintLoss()) {
            continue;
        }
        if (0 != (epoch % config.getLossCalStep())) {
            continue;
        }

        // sum loss
        for (int i = 0; i < nrModelReplica; i++) {
            loss_threads.get(i).sumLoss(threads.get(i).getSamples());
        }

        // waiting for all stop
        while (true) {
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                break;
            }
            boolean allStop = true;
            for(LossThread thread : loss_threads) {
                if (thread.isRunning()) {
                    allStop = false;
                    break;
                }
            }
            if (allStop) {
                break;
            }
        }

        // sum up
        double totalError = 0;
        for(LossThread thread : loss_threads) {
            totalError += thread.getError();
        }
        totalError /= xy_n;
        logger.info("iteration-" + epoch + " done, total error is " + totalError);
        if (totalError <= config.getMinLoss()) {
            break;
        }
    }
}

RBM