본문 바로가기

데이터 분석

Spark - Random Forest Classification

336x280(권장), 300x250(권장), 250x250, 200x200 크기의 광고 코드만 넣을 수 있습니다.




간략 설명



기계 학습에서의 랜덤 포레스트(영어: random forest)는 분류, 회귀 분석 등에 사용되는 앙상블 학습 방법의 일종으로, 훈련 과정에서 구성한 다수의 결정 트리로부터 부류(분류) 또는 평균 예측치(회귀 분석)를 출력함으로써 동작한다.

▶ 랜덤 포레스트는 여러 개의 결정 트리들을 임의적으로 학습하는 방식의 앙상블 방법이다.


랜덤 포레스트 방법은 크게 다수의 결정 트리를 구성하는 학습 단계와 입력 벡터가 들어왔을 때, 분류하거나 예측하는 테스트 단계로 구성되어있다. 랜덤 포레스트는 검출, 분류, 그리고 회귀 등 다양한 애플리케이션으로 활용되고 있다.

 





그림 요약





          N 개의 결정트리로부터 얻어진 결과를

             평균, 곱하기, 또는 과반수 투표 방식을

             통해 최종 결과를 도출



 +

 +

 =

 




 





결과부터 보기 ( Spark Random Forest )


  샘플 데이터로 10개 결정트리 생성(개수는 사용자가 설정함)시 아래 결과


    ..

    ..

    ..  총 10개 Tree



  테스트 데이터 모델에 넣으면 10개의 Tree에서 투표방식을 이용하여 결과 뽑음






샘플 Data

             

0 1:1 2:4 3:1 4:1 5:1 6:3
0 1:1 2:1 3:1 4:1 5:1 6:6
1 1:2 2:1 3:5 4:1 5:1 6:6
0 1:1 2:1 3:1 4:1 5:1 6:1
1 1:2 2:3 3:1 4:1 5:1 6:1
1 1:2 2:3 3:5 4:1 5:1 6:1
0 1:1 2:1 3:1 4:1 5:3 6:1
0 1:1 2:4 3:1 4:1 5:3 6:1
0 1:1 2:1 3:1 4:1 5:1 6:6
1 1:2 2:1 3:5 4:1 5:1 6:1
0 1:1 2:1 3:1 4:1 5:1 6:1
1 1:1 2:3 3:1 4:1 5:1 6:5
1 1:2 2:3 3:5 4:1 5:1 6:2
0 1:1 2:4 3:1 4:1 5:3 6:1
0 1:1 2:1 3:1 4:1 5:1 6:6
1 1:2 2:1 3:1 4:1 5:1 6:1
2 1:2 2:6 3:9 4:1 5:1 6:8
1 1:2 2:6 3:9 4:1 5:1 6:8
2 1:2 2:6 3:9 4:5 5:1 6:8
2 1:2 2:6 3:9 4:6 5:1 6:8
2 1:2 2:6 3:9 4:4 5:1 6:8




코드


        
                SparkConf sparkConf = new SparkConf().setAppName("RandomForestExample")

                                                           .setMaster("local[2]").set("spark.ui.port", "4048");
               
                JavaSparkContext jsc = new JavaSparkContext(sparkConf);
               
               
                String training_path = "/home/ksu/Downloads/trainingValues.txt";
                String test_path = "/home/ksu/Downloads/testValues.txt";
               
                JavaRDD<LabeledPoint> training_data = MLUtils.loadLibSVMFile(jsc.sc(), training_path).toJavaRDD();
                JavaRDD<LabeledPoint> test_data = MLUtils.loadLibSVMFile(jsc.sc(), test_path).toJavaRDD();



                Integer numClasses = 3;
                HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
                Integer numTrees = 10;
                String featureSubsetStrategy = "auto";               
                String impurity = "gini";
                Integer maxDepth = 5;
                Integer maxBins = 100;
                Integer seed = 12345;

                final RandomForestModel model = RandomForest.trainClassifier(training_data, numClasses,
                  categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
                  seed);
                               
              
                JavaPairRDD<Double, Double> predictionAndLabel =
                        test_data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
                            @Override
                            public Tuple2<Double, Double> call(LabeledPoint p) {
                                System.out.println(model.predict(p.features())+" : "+p.label());
                                return new Tuple2<>(model.predict(p.features()), p.label());
                            }
                        });
               
               
                Double testErr =
                        1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
                            @Override
                            public Boolean call(Tuple2<Double, Double> pl) {
                                return !pl._1().equals(pl._2());
                            }
                        }).count() / test_data.count();


                System.out.println("Test Error: " + testErr);
               
                jsc.stop();





참고 사항

 

1. Integer numClasses = 3;

    데이터를 몇개의 분류로 나눌 것인가?


2. Integer numTrees = 10;

    랜덤 포레스트에 포함될 desicion trees의 수


3. String featureSubsetStrategy = "auto"; 

     각 노드에서 분할에 고려해야 할 기능의 수

    지원되는 값 : "auto", "all", "sqrt", "log2", "onethird".

    "auto"가 설정되면 numTrees를 기준으로이 매개 변수가 설정

    numTrees == 1 인 경우 "all"로 설정

    numTrees가 1보다 큰 경우 "sqrt"로 설정


4. String impurity = "gini";


     정보 획득 계산에 사용되는 기준.

    지원되는 값 : "gini"(권장) or "entropy".


5. Integer maxDepth = 5;  

    트리의 최대 깊이(권장 값 : 4)


6. Integer maxBins = 100;


    기능 분할에 사용되는 최대 빈 수 (권장 값 : 100)


7. Integer seed = 12345;


    부트 스트래핑 및 피처 서브 세트 선택을위한 랜덤 시드

 








테스트 결과


Test Error: 0.09523809523809523


예측 라벨 - 실제 라벨

1.0 : 1.0
1.0 : 0.0
1.0 : 1.0
0.0 : 0.0
0.0 : 0.0
1.0 : 1.0
0.0 : 0.0
0.0 : 0.0
1.0 : 1.0
1.0 : 1.0
2.0 : 2.0
1.0 : 1.0
2.0 : 1.0
0.0 : 0.0
2.0 : 2.0
2.0 : 2.0
0.0 : 0.0
2.0 : 2.0
0.0 : 0.0
1.0 : 1.0


2개 빼고 다 맞음











'데이터 분석' 카테고리의 다른 글

Spark - Multilayer perceptron classifier  (0) 2017.05.24
Spark - Random Forest Regression  (0) 2017.05.24
Spark - Linear Regression  (0) 2017.05.19
MSE, RMSE  (0) 2017.05.19
Spark - Naive Bayes (단어를 이용한 문서 구별)  (0) 2017.05.12