□ Isotonic Regression
정확도를 향상 시키기 위해 예측값을 보정(캘리브레이션) 하는 방법에 Isotonic Regression이 많이 사용된다.
Spark MLlib에서는 이 기능이 지원이 된다.
아래 그림은 이 방법이 진행되는 원리를 한눈에 확인할 수 있다.
□ Sample Data
샘플 데이터는 Spark 에서 제공되는 Sample File인 sample_linear_regression_data.txt 을 사용했다.
0.24579296 1:0.01 0.28505864 1:0.02 0.31208567 1:0.03 0.35900051 1:0.04 0.35747068 1:0.05 0.16675166 1:0.06 0.17491076 1:0.07 0.04181540 1:0.08 0.04793473 1:0.09 0.03926568 1:0.10 0.12952575 1:0.11 0.00000000 1:0.12 0.01376849 1:0.13 0.13105558 1:0.14 0.08873024 1:0.15 0.12595614 1:0.16 0.15247323 1:0.17 0.25956145 1:0.18 0.20040796 1:0.19 0.19581846 1:0.20 0.15757267 1:0.21 ... ... ...
|
□ 예제
SparkConf sconf = new SparkConf().setMaster("local[2]") .setAppName("asdf") .set("spark.ui.port", "4041"); JavaSparkContext jsc = new JavaSparkContext(sconf); 데이터를 읽고 JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile( jsc.sc(), "data/mllib/sample_isotonic_regression_libsvm_data.txt") .toJavaRDD();
JavaRDD<Tuple3<Double, Double, Double>> parsedData = data.map( new Function<LabeledPoint, Tuple3<Double, Double, Double>>() { public Tuple3<Double, Double, Double> call(LabeledPoint point) { return new Tuple3<>(new Double(point.label()), new Double(point.features().apply(0)), 1.0); } } ); 6:4로 트레이닝, 테스트 데이터로 나누고 parsedData.randomSplit(new double[]{0.6, 0.4}, 11L); JavaRDD<Tuple3<Double, Double, Double>> training = splits[0]; JavaRDD<Tuple3<Double, Double, Double>> test = splits[1]; final IsotonicRegressionModel model = new IsotonicRegression().setIsotonic(true).run(training); 예상 값과 실제 값 쌍 JavaPairRDD<Double, Double> predictionAndLabel = test.mapToPair( new PairFunction<Tuple3<Double, Double, Double>, Double, Double>() { @Override public Tuple2<Double, Double> call(Tuple3<Double, Double, Double> point) { Double predictedLabel = model.predict(point._2()); return new Tuple2<>(predictedLabel, point._1()); } } ); 평균 제곱 오차로 확인 Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( new Function<Tuple2<Double, Double>, Object>() { @Override public Object call(Tuple2<Double, Double> pl) { return Math.pow(pl._1() - pl._2(), 2); } } ).rdd()).mean(); System.out.println("Mean Squared Error = " + meanSquaredError);
|
□ 결과
Mean Squared Error = 0.005941473430284862 |
'데이터 분석' 카테고리의 다른 글
Spark - AssociationRules (0) | 2017.07.21 |
---|---|
Spark - Multiclass classification (0) | 2017.07.20 |
Spark - FP growth (FP tree) (0) | 2017.06.23 |
Spark - Hypothesis testing ( chi-squared test ) (0) | 2017.06.22 |
Spark - Correlations (2) (0) | 2017.06.21 |