com.flipkart.fdp.ml.adapter.LogisticRegressionModelInfoAdapter1 Maven / Gradle / Ivy
package com.flipkart.fdp.ml.adapter;
import com.flipkart.fdp.ml.modelinfo.LogisticRegressionModelInfo;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.sql.DataFrame;
import java.util.LinkedHashSet;
import java.util.Set;
/**
* Transforms Spark's {@link LogisticRegressionModel} to {@link LogisticRegressionModelInfo} object
* that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}
*/
@Slf4j
public class LogisticRegressionModelInfoAdapter1
extends AbstractModelInfoAdapter {
@Override
public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel, DataFrame df) {
final LogisticRegressionModelInfo logisticRegressionModelInfo = new LogisticRegressionModelInfo();
logisticRegressionModelInfo.setWeights(sparkLRModel.coefficients().toArray());
logisticRegressionModelInfo.setIntercept(sparkLRModel.intercept());
logisticRegressionModelInfo.setNumClasses(sparkLRModel.numClasses());
logisticRegressionModelInfo.setNumFeatures(sparkLRModel.numFeatures());
logisticRegressionModelInfo.setThreshold(sparkLRModel.getThreshold());
logisticRegressionModelInfo.setProbabilityKey(sparkLRModel.getProbabilityCol());
Set inputKeys = new LinkedHashSet();
inputKeys.add(sparkLRModel.getFeaturesCol());
logisticRegressionModelInfo.setInputKeys(inputKeys);
Set outputKeys = new LinkedHashSet();
outputKeys.add(sparkLRModel.getPredictionCol());
outputKeys.add(sparkLRModel.getProbabilityCol());
logisticRegressionModelInfo.setOutputKeys(outputKeys);
return logisticRegressionModelInfo;
}
@Override
public Class getSource() {
return LogisticRegressionModel.class;
}
@Override
public Class getTarget() {
return LogisticRegressionModelInfo.class;
}
}