// com.flipkart.fdp.ml.adapter.OneHotEncoderModelInfoAdapter Maven / Gradle / Ivy
package com.flipkart.fdp.ml.adapter;
import com.flipkart.fdp.ml.modelinfo.OneHotEncoderModelInfo;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeType;
import org.apache.spark.ml.attribute.BinaryAttribute;
import org.apache.spark.ml.attribute.NominalAttribute;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.sql.DataFrame;
import java.util.LinkedHashSet;
import java.util.Set;
/**
 * Transforms Spark's {@link OneHotEncoder} in MlLib to a {@link com.flipkart.fdp.ml.modelinfo.OneHotEncoderModelInfo}
 * object that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}.
 * <p>
 * Exporting Spark's OHE is awkward: the category count is not available on the transformer itself, so it is
 * recovered from the ML attribute metadata attached to the encoder's input column in the supplied
 * {@link DataFrame}, and the encoder's {@code dropLast} setting is not read. Prefer
 * {@link com.flipkart.fdp.ml.CustomOneHotEncoder} where possible.
 */
public class OneHotEncoderModelInfoAdapter extends AbstractModelInfoAdapter<OneHotEncoder, OneHotEncoderModelInfo> {

    /**
     * Builds a {@link OneHotEncoderModelInfo} describing the given encoder.
     *
     * @param from the Spark one-hot encoder being exported
     * @param df   a DataFrame whose schema contains the encoder's input column, carrying nominal or binary
     *             ML attribute metadata
     * @return model info holding the encoded vector size and the input/output column names
     * @throws IllegalArgumentException if the input column's metadata does not describe a known number of
     *                                  categories (numeric or unresolved attributes)
     */
    @Override
    public OneHotEncoderModelInfo getModelInfo(final OneHotEncoder from, DataFrame df) {
        final String inputColumn = from.getInputCol();
        // Ugly but the only way to deal with spark here: the category count lives in the column metadata.
        final Attribute attribute = Attribute.fromStructField(df.schema().apply(inputColumn));
        final int numCategories = resolveNumCategories(attribute, inputColumn);

        OneHotEncoderModelInfo modelInfo = new OneHotEncoderModelInfo();
        //TODO: dropLast is not read from the encoder here, so one category is always dropped, matching an encoder
        // with dropLast enabled (Spark's default). This is the reason we should use CustomOneHotEncoder.
        modelInfo.setNumTypes(numCategories - 1);

        Set<String> inputKeys = new LinkedHashSet<>();
        inputKeys.add(inputColumn);
        modelInfo.setInputKeys(inputKeys);

        Set<String> outputKeys = new LinkedHashSet<>();
        outputKeys.add(from.getOutputCol());
        modelInfo.setOutputKeys(outputKeys);
        return modelInfo;
    }

    /**
     * Returns the number of categories described by the ML attribute of the encoder's input column.
     * <p>
     * The precedence follows Spark's OneHotEncoder:
     * <ul>
     *   <li>Nominal attributes use their value labels when present, otherwise their declared value count.</li>
     *   <li>Binary attributes use their labels when present, otherwise two categories.</li>
     * </ul>
     *
     * @param attribute   attribute parsed from the input column's schema field
     * @param inputColumn name of the input column, used only in the error message
     * @return the number of categories (before any dropLast adjustment)
     * @throws IllegalArgumentException if the count cannot be determined from the metadata
     */
    private static int resolveNumCategories(final Attribute attribute, final String inputColumn) {
        if (attribute instanceof NominalAttribute) {
            final NominalAttribute nominal = (NominalAttribute) attribute;
            if (nominal.values().isDefined()) {
                return nominal.values().get().length;
            }
            if (nominal.numValues().isDefined()) {
                // Scala Option[Int] surfaces in Java as Option<Object> holding a boxed Integer.
                return (Integer) nominal.numValues().get();
            }
        } else if (attribute instanceof BinaryAttribute) {
            final BinaryAttribute binary = (BinaryAttribute) attribute;
            return binary.values().isDefined() ? binary.values().get().length : 2;
        }
        throw new IllegalArgumentException("Cannot determine the number of categories for input column '"
                + inputColumn + "': expected nominal or binary ML attribute metadata with known values, but found "
                + attribute.attrType().name() + " attribute");
    }

    @Override
    public Class<OneHotEncoder> getSource() {
        return OneHotEncoder.class;
    }

    @Override
    public Class<OneHotEncoderModelInfo> getTarget() {
        return OneHotEncoderModelInfo.class;
    }
}