
org.jpmml.sparkml.feature.OneHotEncoderModelConverter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-sparkml Show documentation
Show all versions of pmml-sparkml Show documentation
JPMML Apache Spark ML to PMML converter
/*
* Copyright (c) 2018 Villu Ruusmann
*
* This file is part of JPMML-SparkML
*
* JPMML-SparkML is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SparkML is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.sparkml.feature;
import java.util.ArrayList;
import java.util.List;
import com.google.common.collect.Iterables;
import org.apache.spark.ml.feature.OneHotEncoderModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.sparkml.BinarizedCategoricalFeature;
import org.jpmml.sparkml.MultiFeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;
public class OneHotEncoderModelConverter extends MultiFeatureConverter {
public OneHotEncoderModelConverter(OneHotEncoderModel transformer){
super(transformer);
}
@Override
public List encodeFeatures(SparkMLEncoder encoder){
OneHotEncoderModel transformer = getTransformer();
boolean dropLast = transformer.getDropLast();
InOutMode inputMode = getInputMode();
List result = new ArrayList<>();
String[] inputCols = inputMode.getInputCols(transformer);
for(String inputCol : inputCols){
CategoricalFeature categoricalFeature = (CategoricalFeature)encoder.getOnlyFeature(inputCol);
List> values = categoricalFeature.getValues();
List binaryFeatures = OneHotEncoderModelConverter.encodeFeature(encoder, categoricalFeature, values, dropLast);
result.add(new BinarizedCategoricalFeature(encoder, categoricalFeature, binaryFeatures));
}
return result;
}
@Override
public void registerFeatures(SparkMLEncoder encoder){
OneHotEncoderModel transformer = getTransformer();
List features = encodeFeatures(encoder);
InOutMode outputMode = getOutputMode();
if(outputMode == InOutMode.SINGLE){
String outputCol = transformer.getOutputCol();
BinarizedCategoricalFeature binarizedCategoricalFeature = (BinarizedCategoricalFeature)Iterables.getOnlyElement(features);
encoder.putFeatures(outputCol, (List)binarizedCategoricalFeature.getBinaryFeatures());
} else
if(outputMode == InOutMode.MULTIPLE){
String[] outputCols = transformer.getOutputCols();
SchemaUtil.checkSize(outputCols.length, features);
for(int i = 0; i < outputCols.length; i++){
String outputCol = outputCols[i];
Feature feature = features.get(i);
BinarizedCategoricalFeature binarizedCategoricalFeature = (BinarizedCategoricalFeature)feature;
encoder.putFeatures(outputCol, (List)binarizedCategoricalFeature.getBinaryFeatures());
}
}
}
static
public List encodeFeature(PMMLEncoder encoder, Feature feature, List> values, boolean dropLast){
List result = new ArrayList<>();
if(dropLast){
values = values.subList(0, values.size() - 1);
}
for(Object value : values){
result.add(new BinaryFeature(encoder, feature, value));
}
return result;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy