336x280(권장), 300x250(권장), 250x250, 200x200 크기의 광고 코드만 넣을 수 있습니다.
■ Spark Random Forest Impurity Info
▶ 위 3개의 정보는 불순도(impurity)를 측정하는 하나의 지수로 Spark MLlib Random Forest에서 지원하는 내용이다.
▶ 임의의 한 개체가 목표변수의 i번째
범주로부터 추출되었고, 그 개체를 목표변수의 j번째 범주에 속한다고 오분류(misclassification)할 확률은
P(i)P(j)가 된다.
▶ 여기에서 P(i)는 각 마디에서 한 개체가 목표변수의 I번째 범주에 속할 확률이다. 이러한 오분류 확률은 위의 식으로 모두 더하여 값을 얻을 수 있고, 위와 같은 분류규칙 하에서 오분류 확률의 추정치 가 된다.
▶ RandomForest Java docs 에서
▶ Classification의 경우 "gini"를 추천하고 있으며,
▶ Regression에서는 Variance가 지원된다는 걸 기억하자.
|
앞에서 Classification을 해보았으니 이번에는 Regression를 진행해 보자
Classification에서 데이터들을 특성 정보만으로 0, 1, 2 라는 라벨 중 하나로 분류하였다면 Regression에서는 예상값이 출력 될 것이라고 생각이 든다.
그럼 시작해보자.
샘플 데이터는 Random Forest Classification 에서 썼던 내용과 동일 하다.
있는 소스에서 일부분만 수정해서 테스트를 진행하면서 실수를 했다.
■ 에러 내용
Regression를 진행하는데 impurity를 "gini"를 사용했다. Regression의 경우 Variance를 쓰라고 에러 로그가 발생한다.
그래서 위에 내용을 먼저 언급했다.
Exception in thread "main"
java.lang.IllegalArgumentException: requirement failed: DecisionTree
Strategy given invalid impurity for Regression:
org.apache.spark.mllib.tree.impurity.Gini$@64b3b1ce. Valid settings: Variance at scala.Predef$.require(Predef.scala:224) at org.apache.spark.mllib.tree.configuration.Strategy.assertValid(Strategy.scala:147) at org.apache.spark.mllib.tree.RandomForest.<init>(RandomForest.scala:78) at org.apache.spark.mllib.tree.RandomForest$.trainRegressor(RandomForest.scala:217) at org.apache.spark.mllib.tree.RandomForest$.trainRegressor(RandomForest.scala:258) at org.apache.spark.mllib.tree.RandomForest$.trainRegressor(RandomForest.scala:274) at org.apache.spark.mllib.tree.RandomForest.trainRegressor(RandomForest.scala)
|
■ 예제 소스
▶ impurity : variance를 사용 / 모델 생성 : trainRegressor / 당연히 numClasses를 안함 SparkConf sparkConf1 = new SparkConf().setAppName("RandomForestExample") .setMaster("local[2]").set("spark.ui.port", "4048"); JavaSparkContext jsc = new JavaSparkContext(sparkConf1); String training_path = "/home/ksu/Downloads/trainingValues.txt"; String test_path = "/home/ksu/Downloads/testValues.txt"; JavaRDD<LabeledPoint> training_data = MLUtils.loadLibSVMFile(jsc.sc(), training_path).toJavaRDD(); JavaRDD<LabeledPoint> test_data = MLUtils.loadLibSVMFile(jsc.sc(), test_path).toJavaRDD();
HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>(); Integer numTrees = 3; // Use more in practice. String featureSubsetStrategy = "auto"; // Let the algorithm choose. String impurity = "variance"; Integer maxDepth = 4; Integer maxBins = 100; Integer seed = 12345;
final RandomForestModel model = RandomForest.trainRegressor(training_data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed); JavaPairRDD<Double, Double> predictionAndLabel = test_data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { @Override public Tuple2<Double, Double> call(LabeledPoint p) { System.out.println(model.predict(p.features())+" : "+p.label()); return new Tuple2<>(model.predict(p.features()), p.label()); } }); Double testMSE = predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() { @Override public Double call(Tuple2<Double, Double> pl) { Double diff = pl._1() - pl._2(); return diff * diff; } }).reduce(new Function2<Double, Double, Double>() { @Override public Double call(Double a, Double b) { return a + b; } }) / test_data.count();
System.out.println("Test Error: " + testMSE); System.out.println(model.toDebugString()); jsc.stop();
|
■ 결과
예측값 : 실제값
0.6444444444444445 : 0.0 0.6444444444444445 : 1.0 0.9777777777777779 : 1.0 0.0 : 0.0 0.3844444444444444 : 0.0 1.2 : 1.0 0.0 : 0.0 0.0 : 0.0 1.0 : 1.0 0.9777777777777779 : 1.0 1.777777777777778 : 2.0 0.9777777777777779 : 1.0 1.777777777777778 : 1.0 0.0 : 0.0 2.0 : 2.0 2.0 : 2.0 0.3844444444444444 : 0.0 2.0 : 2.0 0.0 : 0.0 1.0 : 1.0 0.0 : 0.0
|
■ Error & Tree Info
Test Error: 0.07300599647266315 TreeEnsembleModel regressor with 5 trees
Tree 0: If (feature 2 <= 6.0) If (feature 0 <= 1.0) If (feature 2 <= 1.0) If (feature 1 <= 1.0) Predict: 0.0 Else (feature 1 > 1.0) Predict: 0.3333333333333333 Else (feature 2 > 1.0) Predict: 1.0 Else (feature 0 > 1.0) Predict: 1.0 Else (feature 2 > 6.0) Predict: 2.0 Tree 1: If (feature 3 <= 1.0) If (feature 0 <= 1.0) If (feature 5 <= 1.0) If (feature 1 <= 1.0) Predict: 0.0 Else (feature 1 > 1.0) Predict: 0.2 Else (feature 5 > 1.0) If (feature 1 <= 1.0) Predict: 0.0 Else (feature 1 > 1.0) Predict: 1.0 Else (feature 0 > 1.0) If (feature 5 <= 2.0) Predict: 1.0 Else (feature 5 > 2.0) Predict: 2.0 Else (feature 3 > 1.0) Predict: 2.0 Tree 2: If (feature 3 <= 1.0) If (feature 1 <= 1.0) If (feature 0 <= 1.0) If (feature 4 <= 1.0) If (feature 2 <= 1.0) Predict: 0.0 Else (feature 2 > 1.0) Predict: 1.0 Else (feature 4 > 1.0) Predict: 0.0 Else (feature 0 > 1.0) Predict: 1.0 Else (feature 1 > 1.0) Predict: 0.8888888888888888 Else (feature 3 > 1.0) Predict: 2.0 Tree 3: If (feature 1 <= 4.0) .. .. .. ..
|