
org.jpmml.sparkml.feature.SQLTransformerConverter Maven / Gradle / Ivy
/*
* Copyright (c) 2018 Villu Ruusmann
*
* This file is part of JPMML-SparkML
*
* JPMML-SparkML is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SparkML is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SparkML. If not, see <https://www.gnu.org/licenses/>.
*/
package org.jpmml.sparkml.feature;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.ml.feature.SQLTransformer;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.types.StructType;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.VisitorAction;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.TypeUtil;
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.sparkml.AliasExpression;
import org.jpmml.sparkml.DatasetUtil;
import org.jpmml.sparkml.ExpressionTranslator;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;
import scala.collection.JavaConversions;
public class SQLTransformerConverter extends FeatureConverter {
/**
 * Creates a converter for the given {@link SQLTransformer}.
 *
 * @param sqlTransformer the transformer, handed on unchanged to the superclass constructor
 */
public SQLTransformerConverter(SQLTransformer sqlTransformer){
	super(sqlTransformer);
}
@Override
public List encodeFeatures(SparkMLEncoder encoder){
SQLTransformer transformer = getTransformer();
String statement = transformer.getStatement();
SparkSession sparkSession = SparkSession.builder()
.getOrCreate();
StructType schema = encoder.getSchema();
LogicalPlan logicalPlan = DatasetUtil.createAnalyzedLogicalPlan(sparkSession, schema, statement);
List result = new ArrayList<>();
List> objects = encodeLogicalPlan(encoder, logicalPlan);
for(Object object : objects){
if(object instanceof List){
List> features = (List>)object;
features.stream()
.map(Feature.class::cast)
.forEach(result::add);
} else
if(object instanceof Field){
Field> field = (Field>)object;
String name = field.requireName();
Feature feature = encoder.createFeature(field);
encoder.putOnlyFeature(name, feature);
result.add(feature);
} else
{
throw new IllegalArgumentException();
}
}
return result;
}
/**
 * Registers the features of this transformer with the encoder.
 *
 * <p>
 * Delegates to {@link #encodeFeatures(SparkMLEncoder)} and discards the returned list.
 * </p>
 *
 * @param encoder the encoder passed on to {@link #encodeFeatures(SparkMLEncoder)}
 */
@Override
public void registerFeatures(SparkMLEncoder encoder){
	// The return value is intentionally ignored
	encodeFeatures(encoder);
}
static
public List> encodeLogicalPlan(SparkMLEncoder encoder, LogicalPlan logicalPlan){
List
© 2015 - 2025 Weber Informatics LLC | Privacy Policy