// com.alibaba.alink.operator.batch.dataproc.vector.VectorStandardScalerTrainBatchOp
package com.alibaba.alink.operator.batch.dataproc.vector;
import com.alibaba.alink.operator.common.dataproc.vector.VectorStandardScalerModelDataConverter;
import com.alibaba.alink.operator.common.statistics.basicstatistic.BaseVectorSummary;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.types.Row;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.statistics.StatisticsHelper;
import org.apache.flink.ml.api.misc.param.Params;
import com.alibaba.alink.params.dataproc.vector.VectorStandardTrainParams;
import org.apache.flink.util.Collector;
/**
 * Trains a standard-scaler model on a vector column.
 *
 * <p>StandardScaler transforms a data set, normalizing each feature to have unit standard deviation and/or
 * zero mean. If withMean is false, the mean is set to 0; if withStd is false, the std is set to 1.
 *
 * <p>This operator computes the summary of the selected vector column and hands it, together with the
 * withMean/withStd flags, to {@link VectorStandardScalerModelDataConverter} to produce the model rows.
 */
public final class VectorStandardScalerTrainBatchOp extends BatchOperator<VectorStandardScalerTrainBatchOp>
    implements VectorStandardTrainParams<VectorStandardScalerTrainBatchOp> {

    /**
     * Creates the operator with an empty parameter set.
     */
    public VectorStandardScalerTrainBatchOp() {
        this(new Params());
    }

    /**
     * Creates the operator with the given parameters.
     *
     * @param params parameters forwarded to the base operator.
     */
    public VectorStandardScalerTrainBatchOp(Params params) {
        super(params);
    }

    /**
     * Builds the model from the operator returned by {@code checkAndGetFirst(inputs)}.
     *
     * @param inputs upstream operators.
     * @return this operator, with its output set to the model rows.
     */
    @Override
    public VectorStandardScalerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
        BatchOperator<?> in = checkAndGetFirst(inputs);
        String vectorColName = getSelectedCol();

        // This converter instance is only used to obtain the model schema; the rows themselves are
        // produced by a separate converter created inside BuildVectorStandardModel.
        VectorStandardScalerModelDataConverter converter = new VectorStandardScalerModelDataConverter();
        converter.vectorColName = vectorColName;

        DataSet<Row> rows = StatisticsHelper.vectorSummary(in, vectorColName)
            .flatMap(new BuildVectorStandardModel(vectorColName, getWithMean(), getWithStd()));

        setOutput(rows, converter.getModelSchema());
        return this;
    }

    /**
     * Converts a vector summary into model rows.
     */
    public static class BuildVectorStandardModel implements FlatMapFunction<BaseVectorSummary, Row> {
        private final String selectedColName;
        private final boolean withMean;
        private final boolean withStd;

        /**
         * @param selectedColName name of the vector column the model is trained on.
         * @param withMean        value of the withMean flag stored in the model.
         * @param withStd         value of the withStd flag stored in the model.
         */
        public BuildVectorStandardModel(String selectedColName, boolean withMean, boolean withStd) {
            this.selectedColName = selectedColName;
            this.withMean = withMean;
            this.withStd = withStd;
        }

        /**
         * Emits the model rows for one summary; a {@code null} summary produces no output.
         *
         * @param srt       summary of the selected vector column, may be {@code null}.
         * @param collector receives the rows written by the model data converter.
         * @throws Exception declared by the interface contract.
         */
        @Override
        public void flatMap(BaseVectorSummary srt, Collector<Row> collector) throws Exception {
            if (null != srt) {
                VectorStandardScalerModelDataConverter converter = new VectorStandardScalerModelDataConverter();
                converter.vectorColName = selectedColName;
                converter.save(Tuple3.of(withMean, withStd, srt), collector);
            }
        }
    }
}