All downloads are free. The search and download functionality uses the official Maven repository.

com.datastax.insight.ml.spark.ml.regression.RandomForestRegressionWrapper Maven / Gradle / Ivy

package com.datastax.insight.ml.spark.ml.regression;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Random forest regression.
 *
 * <p>Static helpers that configure a {@link RandomForestRegressor}, fit it to a
 * {@link Dataset}, and apply a fitted {@link RandomForestRegressionModel}.
 */
public class RandomForestRegressionWrapper implements DataSetOperator {

    /**
     * Creates a {@link RandomForestRegressor} and applies every supplied setting to it.
     *
     * <p>Each argument is optional: a {@code null} value (or, for the string
     * arguments, an empty string) is skipped, so the corresponding setter is
     * never called and the regressor keeps whatever value it was constructed with.
     *
     * @param labelCol              passed to {@code setLabelCol} when non-empty
     * @param featuresCol           passed to {@code setFeaturesCol} when non-empty
     * @param maxDepth              passed to {@code setMaxDepth} when non-null
     * @param maxBins               passed to {@code setMaxBins} when non-null
     * @param minInstancesPerNode   passed to {@code setMinInstancesPerNode} when non-null
     * @param minInfoGain           passed to {@code setMinInfoGain} when non-null
     * @param maxMemoryInMB         passed to {@code setMaxMemoryInMB} when non-null
     * @param cacheNodeIds          passed to {@code setCacheNodeIds} when non-null
     * @param checkpointInterval    passed to {@code setCheckpointInterval} when non-null
     * @param impurity              passed to {@code setImpurity} when non-empty
     * @param subsamplingRate       passed to {@code setSubsamplingRate} when non-null
     * @param numTrees              passed to {@code setNumTrees} when non-null
     * @param featureSubsetStrategy passed to {@code setFeatureSubsetStrategy} when non-empty
     * @return the newly created, configured regressor (not yet fitted)
     */
    public static RandomForestRegressor getOperator(String labelCol,
                                                    String featuresCol,
                                                    Integer maxDepth,
                                                    Integer maxBins,
                                                    Integer minInstancesPerNode,
                                                    Double minInfoGain,
                                                    Integer maxMemoryInMB,
                                                    Boolean cacheNodeIds,
                                                    Integer checkpointInterval,
                                                    String impurity,
                                                    Double subsamplingRate,
                                                    Integer numTrees,
                                                    String featureSubsetStrategy) {

        RandomForestRegressor regressor = new RandomForestRegressor();

        if (!Strings.isNullOrEmpty(labelCol)) {
            regressor.setLabelCol(labelCol);
        }

        if (!Strings.isNullOrEmpty(featuresCol)) {
            regressor.setFeaturesCol(featuresCol);
        }

        if (maxDepth != null) {
            regressor.setMaxDepth(maxDepth);
        }

        if (maxBins != null) {
            regressor.setMaxBins(maxBins);
        }

        if (minInstancesPerNode != null) {
            regressor.setMinInstancesPerNode(minInstancesPerNode);
        }

        if (minInfoGain != null) {
            regressor.setMinInfoGain(minInfoGain);
        }

        if (maxMemoryInMB != null) {
            regressor.setMaxMemoryInMB(maxMemoryInMB);
        }

        if (cacheNodeIds != null) {
            regressor.setCacheNodeIds(cacheNodeIds);
        }

        if (checkpointInterval != null) {
            regressor.setCheckpointInterval(checkpointInterval);
        }

        if (!Strings.isNullOrEmpty(impurity)) {
            regressor.setImpurity(impurity);
        }

        if (subsamplingRate != null) {
            regressor.setSubsamplingRate(subsamplingRate);
        }

        if (numTrees != null) {
            regressor.setNumTrees(numTrees);
        }

        if (!Strings.isNullOrEmpty(featureSubsetStrategy)) {
            regressor.setFeatureSubsetStrategy(featureSubsetStrategy);
        }

        return regressor;
    }

    /**
     * Builds a regressor with {@link #getOperator} from the given settings and
     * fits it to {@code data}.
     *
     * <p>The settings follow the same rule as in {@link #getOperator}:
     * {@code null} or empty values are skipped.
     *
     * @param data                  the dataset handed to {@code RandomForestRegressor.fit}
     * @param labelCol              forwarded to {@link #getOperator}
     * @param featuresCol           forwarded to {@link #getOperator}
     * @param maxDepth              forwarded to {@link #getOperator}
     * @param maxBins               forwarded to {@link #getOperator}
     * @param minInstancesPerNode   forwarded to {@link #getOperator}
     * @param minInfoGain           forwarded to {@link #getOperator}
     * @param maxMemoryInMB         forwarded to {@link #getOperator}
     * @param cacheNodeIds          forwarded to {@link #getOperator}
     * @param checkpointInterval    forwarded to {@link #getOperator}
     * @param impurity              forwarded to {@link #getOperator}
     * @param subsamplingRate       forwarded to {@link #getOperator}
     * @param numTrees              forwarded to {@link #getOperator}
     * @param featureSubsetStrategy forwarded to {@link #getOperator}
     * @return the model produced by fitting the configured regressor to {@code data}
     */
    public static RandomForestRegressionModel fit(Dataset<?> data,
                                                  String labelCol,
                                                  String featuresCol,
                                                  Integer maxDepth,
                                                  Integer maxBins,
                                                  Integer minInstancesPerNode,
                                                  Double minInfoGain,
                                                  Integer maxMemoryInMB,
                                                  Boolean cacheNodeIds,
                                                  Integer checkpointInterval,
                                                  String impurity,
                                                  Double subsamplingRate,
                                                  Integer numTrees,
                                                  String featureSubsetStrategy) {
        RandomForestRegressor regressor = getOperator(labelCol,
                featuresCol,
                maxDepth,
                maxBins,
                minInstancesPerNode,
                minInfoGain,
                maxMemoryInMB,
                cacheNodeIds,
                checkpointInterval,
                impurity,
                subsamplingRate,
                numTrees,
                featureSubsetStrategy);
        return regressor.fit(data);
    }

    /**
     * Fits an already configured regressor to {@code data}.
     *
     * @param regressor the regressor to fit
     * @param data      the dataset handed to {@code regressor.fit}
     * @return the model returned by {@code regressor.fit(data)}
     */
    public static RandomForestRegressionModel fit(RandomForestRegressor regressor, Dataset<?> data) {
        return regressor.fit(data);
    }

    /**
     * Applies a fitted model to {@code data}.
     *
     * @param model the fitted model
     * @param data  the dataset handed to {@code model.transform}
     * @return the dataset returned by {@code model.transform(data)}
     */
    public static Dataset<Row> transform(RandomForestRegressionModel model, Dataset<?> data) {
        return model.transform(data);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy