Encog provides many training methods.
EncogUtility is a helper class that provides many convenient static functions:

Modifier and Type | Method and Description |
---|---|
static double | `calculateClassificationError(MLClassification method, MLDataSet data)`: Calculate the classification error. |
static double | `calculateRegressionError(MLRegression method, MLDataSet data)` |
static void | `convertCSV2Binary(File csvFile, CSVFormat format, File binFile, int[] input, int[] ideal, boolean headers)` |
static void | `convertCSV2Binary(File csvFile, File binFile, int inputCount, int outputCount, boolean headers)`: Convert a CSV file to a binary training file. |
static void | `convertCSV2Binary(String csvFile, String binFile, int inputCount, int outputCount, boolean headers)`: Convert a CSV file to a binary training file. |
static void | `evaluate(MLRegression network, MLDataSet training)`: Evaluate the network and display (to the console) the output for every value in the training set. |
static void | `explainErrorMSE(MLRegression method, MatrixMLDataSet training)` |
static void | `explainErrorRMS(MLRegression method, MatrixMLDataSet training)` |
static String | `formatNeuralData(MLData data)`: Format neural data as a list of numbers. |
static MLDataSet | `loadCSV2Memory(String filename, int input, int ideal, boolean headers, CSVFormat format, boolean significance)`: Load CSV to memory. |
static MLDataSet | `loadEGB2Memory(File filename)` |
static void | `saveCSV(File targetFile, CSVFormat format, MLDataSet set)` |
static void | `saveEGB(File f, MLDataSet data)`: Save a training set to an EGB file. |
static BasicNetwork | `simpleFeedForward(int input, int hidden1, int hidden2, int output, boolean tanh)`: Create a simple feedforward neural network. |
static void | `trainConsole(BasicNetwork network, MLDataSet trainingSet, int minutes)`: Train the neural network using SCG (scaled conjugate gradient) training, and output status to the console. |
static void | `trainConsole(MLTrain train, BasicNetwork network, MLDataSet trainingSet, int minutes)`: Train the network using the specified training algorithm, and send the output to the console. |
static void | `trainToError(MLMethod method, MLDataSet dataSet, double error)`: Train the method to a specific error, sending the output to the console. |
static void | `trainToError(MLTrain train, double error)`: Train to a specific error using the specified training method, sending the output to the console. |

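The table names three error measures: classification error, MSE and RMS. They can be illustrated with their conventional definitions.

The sketch below is a minimal, standard-library-only illustration, not Encog's implementation. `ErrorMetricsSketch` and its methods are hypothetical names. Encog's exact averaging (for example, per output element versus per record) may differ.

```java
/**
 * Hypothetical helper showing conventional error measures.
 * This is NOT Encog's code; names and averaging rules are assumptions.
 */
final class ErrorMetricsSketch {
    private ErrorMetricsSketch() {}

    /** Fraction of records whose predicted class differs from the ideal class. */
    static double classificationError(int[] predicted, int[] ideal) {
        if (predicted.length != ideal.length || predicted.length == 0) {
            throw new IllegalArgumentException("arrays must be non-empty and of equal length");
        }
        int wrong = 0;
        for (int i = 0; i < predicted.length; i++) {
            if (predicted[i] != ideal[i]) {
                wrong++;
            }
        }
        return (double) wrong / predicted.length;
    }

    /** Mean squared error over every output element of every record (assumes matching shapes). */
    static double mse(double[][] actual, double[][] ideal) {
        double sum = 0.0;
        int count = 0;
        for (int r = 0; r < actual.length; r++) {
            for (int c = 0; c < actual[r].length; c++) {
                double diff = actual[r][c] - ideal[r][c];
                sum += diff * diff;
                count++;
            }
        }
        if (count == 0) {
            throw new IllegalArgumentException("no data");
        }
        return sum / count;
    }

    /** Root mean square error: the square root of the MSE. */
    static double rms(double[][] actual, double[][] ideal) {
        return Math.sqrt(mse(actual, ideal));
    }
}
```

RMS is the square root of MSE, so it is in the same units as the network outputs. That often makes it easier to interpret than MSE.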
BasicTraining is the parent class of Encog's training classes.

Constructor | Description |
---|---|
`BasicTraining()` | Used for serialization. |
`BasicTraining(TrainingImplementationType implementationType)` | |

Return type | Method and Description |
---|---|
void | `addStrategy(Strategy strategy)`: Training strategies can be added to improve the training results. |
void | `finishTraining()`: Should be called after training has completed and the iteration method will not be called any further. |
double | `getError()` |
TrainingImplementationType | `getImplementationType()` |
int | `getIteration()` |
`List<Strategy>` | `getStrategies()` |
MLDataSet | `getTraining()` |
boolean | `isTrainingDone()` |
void | `iteration(int count)`: Perform the specified number of training iterations. |
void | `postIteration()`: Call the strategies after an iteration. |
void | `preIteration()`: Call the strategies before an iteration. |
void | `setError(double error)` |
void | `setIteration(int iteration)`: Set the current training iteration. |
void | `setTraining(MLDataSet training)`: Set the training object that this strategy is working with. |

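`preIteration()`, `postIteration()` and `addStrategy()` describe a pattern: strategies run before and after each iteration. Combined with a loop like the one `EncogUtility.trainToError` describes, that pattern can be sketched as below.

This is a simplified, hypothetical illustration that assumes the trainer reports its error after every iteration. `StrategySketch`, `TrainingLoopSketch` and `ToyTrainer` are invented names, not Encog classes. The toy trainer just runs gradient descent on E(w) = (w - 3)^2.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical hook, loosely modeled on the idea of a training strategy (not Encog's Strategy). */
interface StrategySketch {
    void preIteration(TrainingLoopSketch train);
    void postIteration(TrainingLoopSketch train);
}

/** Hypothetical base class showing the pre/iterate/post pattern; NOT Encog's BasicTraining. */
abstract class TrainingLoopSketch {
    private final List<StrategySketch> strategies = new ArrayList<>();
    private int iteration;
    protected double error = Double.POSITIVE_INFINITY;

    void addStrategy(StrategySketch strategy) { strategies.add(strategy); }
    double getError() { return error; }
    int getIteration() { return iteration; }

    /** One iteration: every strategy runs before and after the algorithm-specific step. */
    final void iteration() {
        iteration++;
        for (StrategySketch s : strategies) s.preIteration(this);
        performIteration();
        for (StrategySketch s : strategies) s.postIteration(this);
    }

    /** Algorithm-specific work; implementations must update {@code error}. */
    protected abstract void performIteration();

    /** Iterate until the error is at or below the target, or until maxIterations is reached. */
    static void trainToError(TrainingLoopSketch train, double target, int maxIterations) {
        while (train.getError() > target && train.getIteration() < maxIterations) {
            train.iteration();
        }
    }
}

/** Toy trainer: fixed-step gradient descent on E(w) = (w - 3)^2. */
final class ToyTrainer extends TrainingLoopSketch {
    double w = 0.0;
    private final double rate;

    ToyTrainer(double rate) { this.rate = rate; }

    @Override
    protected void performIteration() {
        double gradient = 2.0 * (w - 3.0); // dE/dw
        w -= rate * gradient;
        error = (w - 3.0) * (w - 3.0);
    }
}
```

The strategy calls live in a `final` `iteration()` method, and only `performIteration()` is abstract. As a result, every subclass gets the same hook order. The `maxIterations` guard stops the loop if the target error is never reached.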
Backpropagation is a subclass of Propagation.
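
Constructor | Description |
---|---|
`Backpropagation(ContainsFlat network, MLDataSet training)` | Create a class to train using backpropagation. |

The overload that also takes a learning rate and a momentum has four parameters:

- the network to be trained;
- the training set;
- the learning rate;
- the momentum, a common technique for speeding up gradient descent.

The momentum coefficient controls how much of the previous weight change is added to the current one. A value of 0 disables momentum. Larger values (usually below 1) carry more of the previous change forward.
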
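The learning rate and momentum combine as follows in the standard textbook formulation of gradient descent with momentum. This is an assumption: Encog may store gradients with the opposite sign and add them rather than subtract them.

The change to weight $w_i$ at step $t$ combines two terms:

- the current gradient, scaled by the learning rate $\eta$;
- the previous change, scaled by the momentum $\alpha$.

$$
\Delta w_i(t) = -\eta \, \frac{\partial E}{\partial w_i}(t) + \alpha \, \Delta w_i(t-1), \qquad w_i(t+1) = w_i(t) + \Delta w_i(t)
$$

For example, take $\eta = 0.1$, $\alpha = 0.9$, a gradient of $2$ and a previous change of $-0.5$. Then $\Delta w_i(t) = -0.1 \cdot 2 + 0.9 \cdot (-0.5) = -0.2 - 0.45 = -0.65$. With $\alpha = 0$, the update reduces to plain gradient descent, giving $-0.2$.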
Return type | Method and Description |
---|---|
boolean | `canContinue()` |
double[] | `getLastDelta()` |
double | `getLearningRate()` |
double | `getMomentum()` |
void | `initOthers()`: Perform training-method-specific initialization. |
boolean | `isValidResume(TrainingContinuation state)`: Determine if the specified continuation object is valid to resume with. |
TrainingContinuation | `pause()`: Pause the training. |
void | `resume(TrainingContinuation state)`: Resume training. |
void | `setLearningRate(double rate)`: Set the learning rate; this value is essentially a percent. |
void | `setMomentum(double m)`: Set the momentum for training. |
double | `updateWeight(double[] gradients, double[] lastGradient, int index)`: Update a weight. |

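`updateWeight`, `getLastDelta`, `setLearningRate` and `setMomentum` all revolve around a per-weight update. That update can be sketched in plain Java.

This is a hypothetical illustration, not Encog's source:

- `MomentumUpdateSketch` is an invented class.
- Its two-argument `updateWeight(double[] gradients, int index)` differs from Encog's three-argument version.
- It uses the textbook sign convention: subtract the learning-rate-scaled gradient, then add the momentum-scaled previous change.

```java
/**
 * Hypothetical per-weight momentum update, in the spirit of updateWeight(...)
 * and getLastDelta(). NOT Encog's implementation; the signature is invented.
 */
final class MomentumUpdateSketch {
    private final double learningRate;
    private final double momentum;
    private final double[] lastDelta; // previous change for each weight, initially 0

    MomentumUpdateSketch(int weightCount, double learningRate, double momentum) {
        this.learningRate = learningRate;
        this.momentum = momentum;
        this.lastDelta = new double[weightCount];
    }

    /** Change for weight {@code index}: -rate * dE/dw + momentum * previous change. */
    double updateWeight(double[] gradients, int index) {
        double delta = -learningRate * gradients[index] + momentum * lastDelta[index];
        lastDelta[index] = delta; // remembered for the next step's momentum term
        return delta;
    }

    /** Applies one update to every weight in place. */
    void step(double[] weights, double[] gradients) {
        for (int i = 0; i < weights.length; i++) {
            weights[i] += updateWeight(gradients, i);
        }
    }

    /** Returns a copy of the most recent change for each weight. */
    double[] getLastDelta() {
        return lastDelta.clone();
    }

    /** Runs {@code steps} updates on E(w) = w^2 / 2, whose gradient is w, starting from w = 1. */
    static double minimizeQuadratic(double learningRate, double momentum, int steps) {
        MomentumUpdateSketch trainer = new MomentumUpdateSketch(1, learningRate, momentum);
        double[] w = {1.0};
        for (int s = 0; s < steps; s++) {
            trainer.step(w, new double[] {w[0]}); // gradient evaluated at the current w
        }
        return w[0];
    }
}
```

On the toy error E(w) = w^2/2 with learning rate 0.01, 200 plain steps shrink w from 1 to about 0.13, because each step multiplies it by 0.99. The same number of steps with momentum 0.9 brings w much closer to 0. Momentum can also make weights overshoot and oscillate, so its value remains a tuning choice.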