본문 바로가기

데이터 분석

Spark - Random Forest Regression

336x280(권장), 300x250(권장), 250x250, 200x200 크기의 광고 코드만 넣을 수 있습니다.

Spark Random Forest Impurity Info





▶ 위 3개의 정보는 불순도(impurity)를 측정하는 하나의 지수로 Spark MLlib Random Forest에서 지원하는 내용이다.


임의의 한 개체가 목표변수의 i번째 범주로부터 추출되었고, 그 개체를 목표변수의 j번째 범주에 속한다고 오분류(misclassification)할 확률은 P(i)P(j)가 된다.


여기에서 P(i)는 각 마디에서 한 개체가 목표변수의 I번째 범주에 속할 확률이다. 이러한 오분류 확률은 위의 식으로 모두 더하여 값을 얻을 수 있고, 위와 같은 분류규칙 하에서 오분류 확률의 추정치 가 된다.





RandomForest Java docs 에서


Classification의 경우 "gini"를 추천하고 있으며,


Regression에서는 Variance가 지원된다는 걸 기억하자.





앞에서 Classification을 해보았으니 이번에는 Regression를 진행해 보자


Classification에서 데이터들을 특성 정보만으로 0, 1, 2 라는 라벨 중 하나로 분류하였다면 Regression에서는 예상값이 출력 될 것이라고 생각이 든다.


그럼 시작해보자.


샘플 데이터는 Random Forest Classification 에서 썼던 내용과 동일 하다.








있는 소스에서 일부분만 수정해서 테스트를 진행하면서 실수를 했다.



■ 에러 내용



Regression를 진행하는데 impurity를 "gini"를 사용했다. Regression의 경우 Variance를 쓰라고 에러 로그가 발생한다.


그래서 위에 내용을 먼저 언급했다.


Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: DecisionTree Strategy given invalid impurity for Regression: org.apache.spark.mllib.tree.impurity.Gini$@64b3b1ce.  Valid settings: Variance
    at scala.Predef$.require(Predef.scala:224)
    at org.apache.spark.mllib.tree.configuration.Strategy.assertValid(Strategy.scala:147)
    at org.apache.spark.mllib.tree.RandomForest.<init>(RandomForest.scala:78)
    at org.apache.spark.mllib.tree.RandomForest$.trainRegressor(RandomForest.scala:217)
    at org.apache.spark.mllib.tree.RandomForest$.trainRegressor(RandomForest.scala:258)
    at org.apache.spark.mllib.tree.RandomForest$.trainRegressor(RandomForest.scala:274)
    at org.apache.spark.mllib.tree.RandomForest.trainRegressor(RandomForest.scala)







예제 소스


     ▶ impurity : variance를 사용   /   모델 생성 : trainRegressor    /   당연히 numClasses를 안함


                SparkConf sparkConf1 = new SparkConf().setAppName("RandomForestExample")
                        .setMaster("local[2]").set("spark.ui.port", "4048");
               
                JavaSparkContext jsc = new JavaSparkContext(sparkConf1);
               
                String training_path = "/home/ksu/Downloads/trainingValues.txt";
                String test_path = "/home/ksu/Downloads/testValues.txt";
               
                JavaRDD<LabeledPoint> training_data = MLUtils.loadLibSVMFile(jsc.sc(), training_path).toJavaRDD();
                JavaRDD<LabeledPoint> test_data = MLUtils.loadLibSVMFile(jsc.sc(), test_path).toJavaRDD();
               

                HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
                Integer numTrees = 3; // Use more in practice.
                String featureSubsetStrategy = "auto"; // Let the algorithm choose.
                String impurity = "variance";
                Integer maxDepth = 4;
                Integer maxBins = 100;
                Integer seed = 12345;


                final RandomForestModel model = RandomForest.trainRegressor(training_data,
                  categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed);               
              
                JavaPairRDD<Double, Double> predictionAndLabel =
                        test_data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
                            @Override
                            public Tuple2<Double, Double> call(LabeledPoint p) {
                                System.out.println(model.predict(p.features())+" : "+p.label());
                                return new Tuple2<>(model.predict(p.features()), p.label());
                            }
                        });
                Double testMSE =
                          predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
                            @Override
                            public Double call(Tuple2<Double, Double> pl) {
                              Double diff = pl._1() - pl._2();
                              return diff * diff;
                            }
                          }).reduce(new Function2<Double, Double, Double>() {
                            @Override
                            public Double call(Double a, Double b) {
                              return a + b;
                            }
                          }) / test_data.count();


                System.out.println("Test Error: " + testMSE);               
                System.out.println(model.toDebugString());
                jsc.stop();










결과


예측값 : 실제값


0.6444444444444445 : 0.0
0.6444444444444445 : 1.0
0.9777777777777779 : 1.0
0.0 : 0.0
0.3844444444444444 : 0.0
1.2 : 1.0
0.0 : 0.0
0.0 : 0.0
1.0 : 1.0
0.9777777777777779 : 1.0
1.777777777777778 : 2.0
0.9777777777777779 : 1.0
1.777777777777778 : 1.0
0.0 : 0.0
2.0 : 2.0
2.0 : 2.0
0.3844444444444444 : 0.0
2.0 : 2.0
0.0 : 0.0
1.0 : 1.0
0.0 : 0.0

 







Error & Tree Info



Test Error: 0.07300599647266315
TreeEnsembleModel regressor with 5 trees

  Tree 0:
    If (feature 2 <= 6.0)
     If (feature 0 <= 1.0)
      If (feature 2 <= 1.0)
       If (feature 1 <= 1.0)
        Predict: 0.0
       Else (feature 1 > 1.0)
        Predict: 0.3333333333333333
      Else (feature 2 > 1.0)
       Predict: 1.0
     Else (feature 0 > 1.0)
      Predict: 1.0
    Else (feature 2 > 6.0)
     Predict: 2.0
  Tree 1:
    If (feature 3 <= 1.0)
     If (feature 0 <= 1.0)
      If (feature 5 <= 1.0)
       If (feature 1 <= 1.0)
        Predict: 0.0
       Else (feature 1 > 1.0)
        Predict: 0.2
      Else (feature 5 > 1.0)
       If (feature 1 <= 1.0)
        Predict: 0.0
       Else (feature 1 > 1.0)
        Predict: 1.0
     Else (feature 0 > 1.0)
      If (feature 5 <= 2.0)
       Predict: 1.0
      Else (feature 5 > 2.0)
       Predict: 2.0
    Else (feature 3 > 1.0)
     Predict: 2.0
  Tree 2:
    If (feature 3 <= 1.0)
     If (feature 1 <= 1.0)
      If (feature 0 <= 1.0)
       If (feature 4 <= 1.0)
        If (feature 2 <= 1.0)
         Predict: 0.0
        Else (feature 2 > 1.0)
         Predict: 1.0
       Else (feature 4 > 1.0)
        Predict: 0.0
      Else (feature 0 > 1.0)
       Predict: 1.0
     Else (feature 1 > 1.0)
      Predict: 0.8888888888888888
    Else (feature 3 > 1.0)
     Predict: 2.0
  Tree 3:
    If (feature 1 <= 4.0)

               ..

               ..

               ..

               ..


 








'데이터 분석' 카테고리의 다른 글

Spark - ML Pipelines  (0) 2017.06.01
Spark - Multilayer perceptron classifier  (0) 2017.05.24
Spark - Random Forest Classification  (3) 2017.05.23
Spark - Linear Regression  (0) 2017.05.19
MSE, RMSE  (0) 2017.05.19