//Java版本的线性回归的预测代码
package com.swust.machine;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import org.apache.spark.rdd.RDD;
import scala.Tuple2;
import java.util.List;
/**
 * Linear regression demo built on Spark MLlib's SGD-based trainer.
 *
 * <p>The program reads a text file of samples, splits it 80/20 into a training
 * and a test set, fits a {@link LinearRegressionModel} with an intercept, prints
 * the learned parameters, prints up to 20 (prediction, label) pairs and finally
 * reports the mean absolute error on the test set.
 *
 * <p>Author's motto: the clock keeps moving forward, so how could we stop here!
 *
 * @author 雪瞳
 */
public class LinearRegression {

    /** Sample file; each row is parsed as {@code label,f1 f2 ... fn}. */
    private static final String DATA_PATH = "./data/lpsa.data";

    /** Relative weights of the training and test splits. */
    private static final double[] SPLIT_WEIGHTS = {0.8, 0.2};

    /** Fixed seed so the random split is reproducible between runs. */
    private static final long SPLIT_SEED = 1L;

    /** Number of gradient-descent iterations. */
    private static final int NUM_ITERATIONS = 100;

    /** Gradient-descent step size; smaller values such as 0.1 - 0.4 can be tried here. */
    private static final double STEP_SIZE = 1.0;

    /** Fraction of the samples used per iteration; 1.0 means every sample. */
    private static final double MINI_BATCH_FRACTION = 1.0;

    /** Maximum number of (prediction, label) pairs printed for inspection. */
    private static final int PREVIEW_SIZE = 20;

    /**
     * Runs the whole train / predict / evaluate pipeline on a local Spark master.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        SparkConf conf = new SparkConf();
        conf.setMaster("local").setAppName("line");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            // Upper-case is the documented spelling of the level name.
            jsc.setLogLevel("ERROR");

            // Read the raw samples, ignoring blank lines, and parse them.
            JavaRDD<String> data = jsc.textFile(DATA_PATH);
            JavaRDD<LabeledPoint> examples = data
                    .filter(new Function<String, Boolean>() {
                        @Override
                        public Boolean call(String line) throws Exception {
                            return !line.trim().isEmpty();
                        }
                    })
                    .map(new Function<String, LabeledPoint>() {
                        @Override
                        public LabeledPoint call(String line) throws Exception {
                            return parseLabeledPoint(line);
                        }
                    });

            // Split into training and test sets. Both are cached: the trainer
            // scans the training set once per iteration and the test set is
            // traversed several times below.
            RDD<LabeledPoint>[] splits = examples.rdd().randomSplit(SPLIT_WEIGHTS, SPLIT_SEED);
            JavaRDD<LabeledPoint> trainingSet = splits[0].toJavaRDD().cache();
            JavaRDD<LabeledPoint> testSet = splits[1].toJavaRDD().cache();

            LinearRegressionWithSGD lrs = new LinearRegressionWithSGD();
            // Fit an intercept term in addition to the weights.
            lrs.setIntercept(true);
            lrs.optimizer().setStepSize(STEP_SIZE);
            lrs.optimizer().setNumIterations(NUM_ITERATIONS);
            lrs.optimizer().setMiniBatchFraction(MINI_BATCH_FRACTION);

            LinearRegressionModel model = lrs.run(trainingSet.rdd());
            System.err.println(model.weights());
            System.err.println(model.intercept());

            // Without test samples there is nothing to predict, and both the
            // reduce and the division below would fail.
            long testCount = testSet.count();
            if (testCount == 0) {
                System.err.println("Test split is empty; skipping evaluation.");
                return;
            }

            // Predict on the features of the test samples.
            JavaRDD<Double> prediction = model.predict(
                    testSet.map(new Function<LabeledPoint, Vector>() {
                        @Override
                        public Vector call(LabeledPoint labeledPoint) throws Exception {
                            return labeledPoint.features();
                        }
                    }));

            // Pair every prediction with the true label of the same sample.
            JavaPairRDD<Double, Double> predictionAndLabel = prediction.zip(
                    testSet.map(new Function<LabeledPoint, Double>() {
                        @Override
                        public Double call(LabeledPoint labeledPoint) throws Exception {
                            return labeledPoint.label();
                        }
                    }));

            // Show a small preview of the results.
            List<Tuple2<Double, Double>> preview = predictionAndLabel.take(PREVIEW_SIZE);
            System.err.println("prediction label");
            for (Tuple2<Double, Double> pair : preview) {
                System.out.println(pair._1() + " " + pair._2());
            }

            // Mean absolute error: sum of |prediction - label| over the test set,
            // divided by the number of test samples.
            JavaRDD<Double> absoluteErrors = predictionAndLabel.map(
                    new Function<Tuple2<Double, Double>, Double>() {
                        @Override
                        public Double call(Tuple2<Double, Double> pair) throws Exception {
                            return Math.abs(pair._1() - pair._2());
                        }
                    });
            Double totalAbsoluteError = absoluteErrors.reduce(
                    new Function2<Double, Double, Double>() {
                        @Override
                        public Double call(Double left, Double right) throws Exception {
                            return left + right;
                        }
                    });
            double meanAbsoluteError = totalAbsoluteError / testCount;
            System.err.println("Test MAE: " + meanAbsoluteError);
        } finally {
            // Always release the Spark context, even when a step above fails.
            jsc.stop();
        }
    }

    /**
     * Parses one sample line of the form {@code label,f1 f2 ... fn}.
     *
     * @param line a non-blank line of the sample file
     * @return the labeled point holding the label and the dense feature vector
     * @throws IllegalArgumentException if the line has no comma-separated feature part
     * @throws NumberFormatException if the label or a feature is not a valid number
     */
    private static LabeledPoint parseLabeledPoint(String line) {
        String[] parts = line.split(",");
        if (parts.length < 2) {
            throw new IllegalArgumentException(
                    "Malformed sample line (expected 'label,f1 f2 ...'): " + line);
        }
        // Split on runs of whitespace so repeated or leading spaces are tolerated.
        String[] tokens = parts[1].trim().split("\\s+");
        double[] features = new double[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            features[i] = Double.parseDouble(tokens[i]);
        }
        return new LabeledPoint(Double.parseDouble(parts[0].trim()), Vectors.dense(features));
    }
}
//由于数据量本身只有100条 所以预测效果相对较差 但是重要的是思路不是嘛
// 有道无术术可求 有术无道止于术 学会一个思想更为重要