// Listing: com.flipkart.fdp.ml.adapter.StringIndexerModelInfoAdapter (Maven / Gradle / Ivy)
// The newest version!
package com.flipkart.fdp.ml.adapter;
import com.flipkart.fdp.ml.modelinfo.StringIndexerModelInfo;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.sql.DataFrame;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
/**
 * Transforms Spark's {@link StringIndexerModel} (MLlib) into a
 * {@link com.flipkart.fdp.ml.modelinfo.StringIndexerModelInfo} object that can be exported through
 * {@link com.flipkart.fdp.ml.export.ModelExporter}.
 */
public class StringIndexerModelInfoAdapter
        extends AbstractModelInfoAdapter<StringIndexerModel, StringIndexerModelInfo> {

    /**
     * Builds a {@link StringIndexerModelInfo} from a fitted {@link StringIndexerModel}.
     *
     * <p>Each label returned by {@code from.labels()} is mapped to its position in that array,
     * stored as a {@code Double}. The model's input column becomes the sole input key and its
     * output column the sole output key.
     *
     * @param from the fitted Spark string indexer to convert
     * @param df   not read by this adapter
     * @return a newly created, populated {@link StringIndexerModelInfo}
     */
    @Override
    public StringIndexerModelInfo getModelInfo(final StringIndexerModel from, final DataFrame df) {
        final String[] labels = from.labels();
        final Map<String, Double> labelToIndex = new HashMap<String, Double>();
        for (int i = 0; i < labels.length; i++) {
            // The index is stored as a Double (boxed from the int position), as the original did.
            labelToIndex.put(labels[i], (double) i);
        }

        final StringIndexerModelInfo modelInfo = new StringIndexerModelInfo();
        modelInfo.setLabelToIndex(labelToIndex);

        // LinkedHashSet keeps insertion order and leaves the sets mutable for later consumers.
        final Set<String> inputKeys = new LinkedHashSet<String>();
        inputKeys.add(from.getInputCol());
        modelInfo.setInputKeys(inputKeys);

        final Set<String> outputKeys = new LinkedHashSet<String>();
        outputKeys.add(from.getOutputCol());
        modelInfo.setOutputKeys(outputKeys);

        return modelInfo;
    }

    /**
     * Returns the Spark model type this adapter consumes.
     *
     * @return {@code StringIndexerModel.class}
     */
    @Override
    public Class<StringIndexerModel> getSource() {
        return StringIndexerModel.class;
    }

    /**
     * Returns the model-info type this adapter produces.
     *
     * @return {@code StringIndexerModelInfo.class}
     */
    @Override
    public Class<StringIndexerModelInfo> getTarget() {
        return StringIndexerModelInfo.class;
    }
}