com.datastax.insight.ml.spark.ml.classification.MultilayerPerceptronClassifierWrapper Maven / Gradle / Ivy
package com.datastax.insight.ml.spark.ml.classification;
import com.datastax.insight.spec.DataSetOperator;
import com.datastax.insight.core.Consts;
import com.google.common.base.Strings;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
/**
 * Multilayer perceptron classifier wrapper.
 *
 * <p>Static helpers that build, fit and apply a Spark ML {@link MultilayerPerceptronClassifier}.
 * Every tuning argument is optional: a {@code null} value (or, for strings, a null or empty
 * value) means the corresponding setter is not called on the classifier.
 */
public class MultilayerPerceptronClassifierWrapper implements DataSetOperator {

    /**
     * Builds a {@link MultilayerPerceptronClassifier} configured from the supplied arguments.
     *
     * @param featuresCol name of the features column; skipped when null or empty
     * @param labelCol name of the label column; skipped when null or empty
     * @param layer layer sizes separated by {@link Consts#DELIMITER}; whitespace around each
     *     size is ignored; skipped when null or empty
     * @param blockSize value passed to {@code setBlockSize}; skipped when null
     * @param seed value passed to {@code setSeed}; skipped when null
     * @param maxIterations value passed to {@code setMaxIter}; skipped when null
     * @param tol value passed to {@code setTol}; skipped when null
     * @param stepSize value passed to {@code setStepSize}; skipped when null
     * @param solver value passed to {@code setSolver}; skipped when null or empty
     * @return the configured, unfitted classifier
     * @throws IllegalArgumentException if {@code layer} contains a blank or non-integer entry
     */
    public static MultilayerPerceptronClassifier getOperator(String featuresCol,
                                                             String labelCol,
                                                             String layer,
                                                             Integer blockSize,
                                                             Long seed,
                                                             Integer maxIterations,
                                                             Double tol,
                                                             Double stepSize,
                                                             String solver) {
        MultilayerPerceptronClassifier classifier = new MultilayerPerceptronClassifier();
        if (!Strings.isNullOrEmpty(featuresCol)) {
            classifier.setFeaturesCol(featuresCol);
        }
        if (!Strings.isNullOrEmpty(labelCol)) {
            classifier.setLabelCol(labelCol);
        }
        if (!Strings.isNullOrEmpty(layer)) {
            classifier.setLayers(parseLayers(layer));
        }
        if (blockSize != null) {
            classifier.setBlockSize(blockSize);
        }
        if (seed != null) {
            classifier.setSeed(seed);
        }
        if (maxIterations != null) {
            classifier.setMaxIter(maxIterations);
        }
        if (tol != null) {
            classifier.setTol(tol);
        }
        if (stepSize != null) {
            classifier.setStepSize(stepSize);
        }
        if (!Strings.isNullOrEmpty(solver)) {
            classifier.setSolver(solver);
        }
        return classifier;
    }

    /**
     * Parses a delimited list of layer sizes into an {@code int[]}.
     *
     * <p>Each entry is trimmed before parsing, so e.g. {@code "4, 5, 3"} is accepted when the
     * delimiter is a comma.
     *
     * @param layer non-empty delimited layer specification
     * @return the parsed layer sizes, in the order they appear in {@code layer}
     * @throws IllegalArgumentException if an entry is blank or not a valid integer
     */
    private static int[] parseLayers(String layer) {
        // NOTE(review): String.split interprets Consts.DELIMITER as a regular expression;
        // confirm it contains no unescaped metacharacters (e.g. "|" or ".").
        String[] tokens = layer.split(Consts.DELIMITER);
        int[] layers = new int[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            try {
                layers[i] = Integer.parseInt(tokens[i].trim());
            } catch (NumberFormatException e) {
                // Rethrow with context; NumberFormatException is itself an
                // IllegalArgumentException, so existing catch sites still match.
                throw new IllegalArgumentException(
                        "Invalid layer size '" + tokens[i] + "' at position " + i
                                + " in layer specification '" + layer + "'", e);
            }
        }
        return layers;
    }

    /**
     * Builds a classifier via
     * {@link #getOperator(String, String, String, Integer, Long, Integer, Double, Double, String)}
     * and fits it on {@code data}.
     *
     * @param data training dataset
     * @return the fitted model
     * @throws IllegalArgumentException if {@code layer} contains a blank or non-integer entry
     */
    public static MultilayerPerceptronClassificationModel fit(Dataset<?> data,
                                                              String featuresCol,
                                                              String labelCol,
                                                              String layer,
                                                              Integer blockSize,
                                                              Long seed,
                                                              Integer maxIterations,
                                                              Double tol,
                                                              Double stepSize,
                                                              String solver) {
        MultilayerPerceptronClassifier classifier = getOperator(featuresCol, labelCol, layer,
                blockSize, seed, maxIterations, tol, stepSize, solver);
        return fit(classifier, data);
    }

    /**
     * Fits an already-configured classifier on {@code data}.
     *
     * @param classifier the classifier to fit
     * @param data training dataset
     * @return the fitted model
     */
    public static MultilayerPerceptronClassificationModel fit(MultilayerPerceptronClassifier classifier,
                                                              Dataset<?> data) {
        return classifier.fit(data);
    }

    /**
     * Applies a fitted model to {@code data}.
     *
     * @param model the fitted model
     * @param data dataset to transform
     * @return the dataset returned by {@code model.transform(data)}
     */
    public static Dataset<Row> transform(MultilayerPerceptronClassificationModel model, Dataset<?> data) {
        return model.transform(data);
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy