获取数据,放到List中
将数据集划分为训练集、验证集、测试集
新建RBM对象,确定可见层、隐含层的大小
训练RBM
新建线程集:为每个模型副本创建一个训练线程(DeltaThread)和一个损失计算线程(LossThread)
/**
 * Trains {@code sgd} using several model replicas that each work on a random
 * partition of {@code samples}.
 *
 * <p>Per epoch the method:
 * <ol>
 *   <li>calls {@code train(epoch)} on every {@code DeltaThread} and polls
 *       {@code isRunning()} once per second until none reports running;</li>
 *   <li>merges every replica's parameters into {@code sgd} via
 *       {@code mergeParam};</li>
 *   <li>optionally writes the parameters out every
 *       {@code config.getParamOutputStep()} epochs;</li>
 *   <li>optionally, every {@code config.getLossCalStep()} epochs, computes the
 *       summed error over all partitions divided by the total sample count and
 *       stops early once it is {@code <= config.getMinLoss()}.</li>
 * </ol>
 *
 * <p>If the calling thread is interrupted while waiting for the workers, the
 * interrupt status is restored and training is aborted immediately, so that
 * results of workers that are still running are never merged or summed.
 *
 * @param sgd     the model being trained; receives the merged parameters
 * @param samples all training samples; each one is assigned to exactly one replica
 * @param config  training configuration (replica count, epochs, output and loss options)
 */
public static void train(SGDBase sgd, List<SampleVector> samples, SGDTrainConfig config) {
    final int sampleCount = samples.size();
    final int nrModelReplica = config.getNbrModelReplica();

    // Partition the data set: one bucket per replica, samples assigned uniformly at random.
    List<List<SampleVector>> partitions = new ArrayList<List<SampleVector>>();
    for (int i = 0; i < nrModelReplica; i++) {
        partitions.add(new ArrayList<SampleVector>());
    }
    Random rand = new Random(System.currentTimeMillis());
    for (SampleVector sample : samples) {
        partitions.get(rand.nextInt(nrModelReplica)).add(sample);
    }

    // Create one training worker (with its partition) and one loss worker per replica.
    List<DeltaThread> threads = new ArrayList<DeltaThread>();
    List<LossThread> loss_threads = new ArrayList<LossThread>();
    for (int i = 0; i < nrModelReplica; i++) {
        threads.add(new DeltaThread(sgd, config, partitions.get(i)));
        loss_threads.add(new LossThread(sgd));
    }

    for (int epoch = 1; epoch <= config.getMaxEpochs(); epoch++) {
        // Kick off this epoch on every replica, then wait for all of them.
        for (DeltaThread thread : threads) {
            thread.train(epoch);
        }
        if (!awaitDeltaThreads(threads)) {
            // Merging now would read parameters from workers that are still running.
            logger.info("training interrupted while waiting for iteration-" + epoch);
            return;
        }

        // Fold every replica's parameters back into the shared model.
        for (DeltaThread thread : threads) {
            sgd.mergeParam(thread.getParam(), nrModelReplica);
        }
        logger.info("train done for this iteration-" + epoch);

        // 1. Periodic parameter output.
        if (config.isParamOutput() && (0 == (epoch % config.getParamOutputStep()))) {
            SGDPersistableWrite.output(config.getParamOutputPath(), sgd);
        }

        // 2. Periodic loss report; skipped unless enabled and due this epoch.
        if (!config.isPrintLoss() || 0 != (epoch % config.getLossCalStep())) {
            continue;
        }

        // Each loss worker evaluates the partition of the matching training worker.
        for (int i = 0; i < nrModelReplica; i++) {
            loss_threads.get(i).sumLoss(threads.get(i).getSamples());
        }
        if (!awaitLossThreads(loss_threads)) {
            // A partial error sum could wrongly satisfy the min-loss stop criterion.
            logger.info("training interrupted while computing loss for iteration-" + epoch);
            return;
        }

        double totalError = 0;
        for (LossThread thread : loss_threads) {
            totalError += thread.getError();
        }
        totalError /= sampleCount;
        logger.info("iteration-" + epoch + " done, total error is " + totalError);
        if (totalError <= config.getMinLoss()) {
            break;
        }
    }
}

/**
 * Polls once per second until no training worker reports {@code isRunning()}.
 *
 * @return {@code true} when all workers have stopped; {@code false} if the
 *         calling thread was interrupted (its interrupt status is restored)
 */
private static boolean awaitDeltaThreads(List<DeltaThread> workers) {
    while (true) {
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
        boolean anyRunning = false;
        for (DeltaThread worker : workers) {
            if (worker.isRunning()) {
                anyRunning = true;
                break;
            }
        }
        if (!anyRunning) {
            return true;
        }
    }
}

/**
 * Polls once per second until no loss worker reports {@code isRunning()}.
 *
 * @return {@code true} when all workers have stopped; {@code false} if the
 *         calling thread was interrupted (its interrupt status is restored)
 */
private static boolean awaitLossThreads(List<LossThread> workers) {
    while (true) {
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
        boolean anyRunning = false;
        for (LossThread worker : workers) {
            if (worker.isRunning()) {
                anyRunning = true;
                break;
            }
        }
        if (!anyRunning) {
            return true;
        }
    }
}