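Handwritten Digit Recognition
1. BP Neural Network
1.1 How the BP algorithm works
BP stands for back propagation. The error measured at the output layer is used to compute the error of the preceding layer, and so on back to the input layer. The connection weights between the layers are then adjusted to reduce that error.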
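1.2 Three-layer network structure
The network has three layers: an input layer, one hidden layer, and an output layer.
1.3 Sigmoid activation function
Sigmoid: f(x) = 1 / (1 + e^(-x))
f(x) takes values in the open interval (0, 1).
1.4 Computing the back-propagated error and adjusting the weights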
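The error computation and weight adjustment can be illustrated in plain Java. What follows is a minimal sketch of the standard delta rule, not the implementation used later in these notes. It assumes a squared-error loss, sigmoid units in both the hidden and output layers, per-sample (online) updates, and bias weights stored in the last column of each weight matrix. The class name BackpropSketch and all of its members are illustrative.

```java
import java.util.Random;

/** One-hidden-layer sigmoid network trained with the delta rule (illustrative sketch). */
final class BackpropSketch {
    final double[][] w1; // [hidden][input + 1], last column holds the bias
    final double[][] w2; // [output][hidden + 1], last column holds the bias
    final double learningRate;

    BackpropSketch(int in, int hidden, int out, double learningRate, long seed) {
        Random rnd = new Random(seed);
        this.w1 = randomMatrix(hidden, in + 1, rnd);
        this.w2 = randomMatrix(out, hidden + 1, rnd);
        this.learningRate = learningRate;
    }

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    static double[][] randomMatrix(int rows, int cols, Random rnd) {
        double[][] m = new double[rows][cols];
        for (double[] row : m) {
            for (int c = 0; c < cols; c++) {
                row[c] = rnd.nextDouble() - 0.5; // small weights in [-0.5, 0.5)
            }
        }
        return m;
    }

    /** Forward pass of one layer: out[j] = sigmoid(sum_i w[j][i] * in[i] + bias[j]). */
    static double[] forward(double[][] w, double[] in) {
        double[] out = new double[w.length];
        for (int j = 0; j < w.length; j++) {
            double sum = w[j][in.length]; // bias term
            for (int i = 0; i < in.length; i++) {
                sum += w[j][i] * in[i];
            }
            out[j] = sigmoid(sum);
        }
        return out;
    }

    double[] predict(double[] x) {
        return forward(w2, forward(w1, x));
    }

    /** One training step; returns the error 0.5 * sum((t - y)^2) measured before the update. */
    double train(double[] x, double[] target) {
        double[] h = forward(w1, x);
        double[] y = forward(w2, h);

        // Output-layer error term: (t - y) * f'(net), where f' = y * (1 - y) for the sigmoid
        double[] deltaOut = new double[y.length];
        double error = 0.0;
        for (int k = 0; k < y.length; k++) {
            double diff = target[k] - y[k];
            error += 0.5 * diff * diff;
            deltaOut[k] = diff * y[k] * (1.0 - y[k]);
        }

        // Hidden-layer error term: output errors propagated back through w2 (before w2 changes)
        double[] deltaHidden = new double[h.length];
        for (int j = 0; j < h.length; j++) {
            double sum = 0.0;
            for (int k = 0; k < y.length; k++) {
                sum += deltaOut[k] * w2[k][j];
            }
            deltaHidden[j] = sum * h[j] * (1.0 - h[j]);
        }

        update(w2, deltaOut, h);
        update(w1, deltaHidden, x);
        return error;
    }

    /** Weight adjustment: w[j][i] += learningRate * delta[j] * in[i]; the bias input is 1. */
    private void update(double[][] w, double[] delta, double[] in) {
        for (int j = 0; j < w.length; j++) {
            for (int i = 0; i < in.length; i++) {
                w[j][i] += learningRate * delta[j] * in[i];
            }
            w[j][in.length] += learningRate * delta[j];
        }
    }
}
```

Calling train repeatedly over a small data set (for example the four rows of the logical OR table) should drive the returned error down, which is the behaviour section 1.1 describes.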
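2. Handwritten digit recognition with BP
2.1 Formatting the input
Select 28 images of each handwritten digit from 0 to 9 from the MNIST dataset.
Convert each image matrix into the one-dimensional array that the BP network accepts: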
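// Required imports: java.awt.image.BufferedImage, java.io.File, javax.imageio.ImageIO
// Convert an image file into a pixel vector
private double[] getImagePixel(String image) throws Exception {
    // Let read errors propagate instead of continuing with a null image
    BufferedImage bi = ImageIO.read(new File(image));
    if (bi == null) {
        throw new IllegalArgumentException("Unsupported image file: " + image);
    }
    int width = bi.getWidth();
    int height = bi.getHeight();
    // One output value per complete 2x2 block
    double[] vector = new double[(width / 2) * (height / 2)];
    int k = 0;
    for (int i = 0; i + 1 < width; i += 2) {
        for (int j = 0; j + 1 < height; j += 2) {
            // Count the white pixels inside the current 2x2 block
            int whiteNum = 0;
            for (int m = 0; m < 2; m++) {
                for (int n = 0; n < 2; n++) {
                    if (isWhite(bi, i + m, j + n)) {
                        whiteNum++;
                    }
                }
            }
            // Floating-point division; the integer form whiteNum / 4 truncates to 0 or 1
            vector[k++] = whiteNum / 4.0;
        }
    }
    return vector;
}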
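To reduce the amount of computation, each 2×2 block of pixels is replaced by a single value derived from the number of white pixels in the block, which cuts the amount of data to 1/4 of the original.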
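To make the 2×2 reduction concrete, here is a small sketch that works on a plain boolean grid instead of a BufferedImage. The class BlockDownsample and the method whiteFraction are illustrative names, and the grid is assumed to be rectangular with even dimensions. Unlike the listing above, it walks the grid row by row.

```java
/** Reduces a black-and-white grid by 2x2 blocks (illustrative sketch). */
final class BlockDownsample {
    /** Returns, for each 2x2 block in row-major order, the fraction of white cells. */
    static double[] whiteFraction(boolean[][] white) {
        int rows = white.length / 2;
        int cols = white[0].length / 2;
        double[] vector = new double[rows * cols];
        int k = 0;
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                int count = 0;
                for (int dr = 0; dr < 2; dr++) {
                    for (int dc = 0; dc < 2; dc++) {
                        if (white[2 * r + dr][2 * c + dc]) {
                            count++;
                        }
                    }
                }
                vector[k++] = count / 4.0; // 0.0, 0.25, 0.5, 0.75 or 1.0
            }
        }
        return vector;
    }

    public static void main(String[] args) {
        boolean[][] grid = {
            {true,  true,  false, false},
            {true,  false, false, false},
            {false, false, true,  true},
            {false, false, true,  true}
        };
        // 16 cells become 4 values: [0.75, 0.0, 0.0, 1.0]
        System.out.println(java.util.Arrays.toString(whiteFraction(grid)));
    }
}
```

The 4×4 grid shrinks to four values, which is the 1/4 reduction described above.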
// Binarize and normalize the data: a pixel counts as white when its gray level exceeds 100
private boolean isWhite(BufferedImage bi, int i, int j) {
    int pixel = bi.getRGB(i, j);
    int r = (pixel & 0xff0000) >> 16;
    int g = (pixel & 0xff00) >> 8;
    int b = pixel & 0xff;
    // Integer gray-level approximation; the weights 38 + 75 + 15 sum to 128, hence >> 7
    int gray = (r * 38 + g * 75 + b * 15) >> 7;
    return gray > 100;
}
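The integer gray-level formula is easier to follow with a few concrete colors. The sketch below repeats the same arithmetic on packed RGB values of the kind returned by BufferedImage.getRGB. The class GrayThreshold and its method names are illustrative.

```java
/** Gray-level thresholding on packed RGB values (illustrative sketch). */
final class GrayThreshold {
    /** Weighted gray level in 0..255; the weights 38, 75 and 15 sum to 128, so >> 7 rescales. */
    static int gray(int rgb) {
        int r = (rgb & 0xff0000) >> 16;
        int g = (rgb & 0xff00) >> 8;
        int b = rgb & 0xff;
        return (r * 38 + g * 75 + b * 15) >> 7;
    }

    /** Same cutoff as the listing above: strictly greater than 100 counts as white. */
    static boolean isWhite(int rgb) {
        return gray(rgb) > 100;
    }

    public static void main(String[] args) {
        System.out.println(gray(0xFFFFFF)); // white: 255
        System.out.println(gray(0xFF0000)); // pure red: 75
        System.out.println(gray(0x00FF00)); // pure green: 149
        System.out.println(gray(0x0000FF)); // pure blue: 29
        System.out.println(isWhite(0x646464)); // gray level exactly 100: false
    }
}
```

Green contributes most to the gray level and blue least, so a saturated green pixel passes the threshold while saturated red and blue pixels do not.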
2.2 Creating and training the neural network
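Train the network until it reaches the iteration limit or the total error drops below the threshold. The listing below sets the limit to 50000 iterations and the error threshold to 0.0001.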
// Imports needed by this listing (Neuroph 2.9x package names):
// java.io.File, java.util.Stack,
// org.neuroph.core.data.DataSet,
// org.neuroph.core.events.LearningEvent, org.neuroph.core.events.LearningEventListener,
// org.neuroph.core.learning.LearningRule,
// org.neuroph.nnet.MultiLayerPerceptron, org.neuroph.nnet.learning.BackPropagation,
// org.neuroph.util.TransferFunctionType

// Backslashes in a Java string literal must be escaped
String trainDataPath = "C:\\Users\\Administrator\\Desktop\\or_perceptron.nnet";
String imgPath = "data/train2/";
int maxLearn = 50000;
double maxError = 0.0001;
// Holds the timestamp of the previous progress report
Stack<Long> stack = new Stack<>();
stack.push(System.currentTimeMillis());
System.out.println("->Initializing multilayer network...");
// 100 inputs, 20 hidden neurons, 10 outputs (one per digit)
MultiLayerPerceptron myMlPerceptron = new MultiLayerPerceptron(TransferFunctionType.SIGMOID, 100, 20, 10);
BackPropagation bp = new BackPropagation();
bp.setMaxIterations(maxLearn);
bp.setMaxError(maxError);
bp.setLearningRate(0.5);
bp.setMinErrorChange(0.0001);
// bp.setBatchMode(true);
// bp.setMinErrorChangeIterationsLimit(1000);
myMlPerceptron.setLearningRule(bp);
System.out.println("->Initialization complete");
LearningRule learningRule = myMlPerceptron.getLearningRule();
// Report progress every 100 iterations
learningRule.addListener(new LearningEventListener() {
    @Override
    public void handleLearningEvent(LearningEvent event) {
        BackPropagation bp = (BackPropagation) event.getSource();
        int iteration = bp.getCurrentIteration();
        if (event.getEventType() != LearningEvent.Type.LEARNING_STOPPED && iteration % 100 == 0) {
            System.out.print("->Iteration: " + iteration + ", current error: " + bp.getTotalNetworkError());
            System.out.println(", elapsed: " + (System.currentTimeMillis() - stack.pop()) + " ms");
            stack.push(System.currentTimeMillis());
        }
    }
});
System.out.println("->Creating data set...");
DataSet trainingSet = new DataSet(100, 10);
// Add the training data to the data set
File f = new File(imgPath);
File[] list = f.listFiles();
// ImageToVector is not shown here; it must return 100 values per image to match the input layer
ImageToVector itv = new ImageToVector();
for (int i = 0; i < list.length; i++) {
    String fileName = list[i].getName();
    double[] input = itv.imageToVector(imgPath + fileName, 20, 20);
    trainingSet.addRow(input, getTarget(getNumber(fileName)));
}
System.out.println("->Data set created");
System.out.println("->Starting training...");
myMlPerceptron.learn(trainingSet);
System.out.println("->Training complete");
System.out.println("->Saving trained network...");
myMlPerceptron.save(trainDataPath);
System.out.println("->Save complete");
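The helpers getTarget and getNumber are called in the listing but their code is not shown. The following is a hypothetical sketch of what they might look like, under two assumptions that these notes do not confirm: the target is a one-hot vector of length 10 (matching the 10 output neurons), and each file name begins with the digit it depicts (for example 7_012.png). The class name TargetEncoding is illustrative.

```java
/** Hypothetical label helpers for the training loop (illustrative sketch). */
final class TargetEncoding {
    /** One-hot target: 1.0 at the index of the digit, 0.0 elsewhere. */
    static double[] getTarget(int digit) {
        if (digit < 0 || digit > 9) {
            throw new IllegalArgumentException("Digit out of range: " + digit);
        }
        double[] target = new double[10];
        target[digit] = 1.0;
        return target;
    }

    /** Assumes the first character of the file name is the digit label. */
    static int getNumber(String fileName) {
        char c = fileName.charAt(0);
        if (c < '0' || c > '9') {
            throw new IllegalArgumentException("No leading digit in: " + fileName);
        }
        return c - '0';
    }

    public static void main(String[] args) {
        // Prints [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
        System.out.println(java.util.Arrays.toString(getTarget(getNumber("7_012.png"))));
    }
}
```

If the real data set encodes labels differently (for instance in a directory name), only getNumber would need to change.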
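2.3 Testing the network
The test used 4990 images, of which 2646 were recognized correctly, an accuracy of 0.5302605 (about 53%). The recognition rate is low.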
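The test code itself is not shown, so the following is only a sketch of how such an accuracy figure might be computed. It assumes the predicted digit is the index of the largest of the 10 outputs, and that the output vectors have already been collected from the trained network (in Neuroph this would typically be done with setInput, calculate and getOutput). The class AccuracySketch and its methods are illustrative names.

```java
/** Computes classification accuracy from raw network outputs (illustrative sketch). */
final class AccuracySketch {
    /** Index of the largest value; the first one wins on ties. */
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Fraction of samples whose strongest output matches the expected label. */
    static double accuracy(double[][] outputs, int[] labels) {
        if (outputs.length != labels.length || outputs.length == 0) {
            throw new IllegalArgumentException("Need one label per output vector");
        }
        int correct = 0;
        for (int n = 0; n < outputs.length; n++) {
            if (argMax(outputs[n]) == labels[n]) {
                correct++;
            }
        }
        return (double) correct / outputs.length;
    }
}
```

With 2646 correct predictions out of 4990 samples, this calculation gives the 0.5302605 reported above.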