All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.alibaba.alink.operator.common.regression.IsotonicRegressionModelMapper Maven / Gradle / Ivy

package com.alibaba.alink.operator.common.regression;

import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.table.api.Types;
import org.apache.flink.types.Row;

import com.alibaba.alink.operator.common.dataproc.ScalerUtil;
import com.alibaba.alink.common.linalg.VectorUtil;
import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.utils.RowUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.regression.IsotonicRegTrainParams;
import com.alibaba.alink.params.shared.colname.HasPredictionCol;
import org.apache.commons.lang3.ArrayUtils;

import java.util.Arrays;
import java.util.List;

/**
 * This mapper predicts the isotonic regression result.
 *
 * <p>The feature is read either from a plain numeric column or from one component of a vector
 * column, depending on which of the two was recorded in the model meta. The prediction is
 * obtained by locating the feature among the model's {@code boundaries} array and is appended
 * to the input row as one extra DOUBLE column.
 */
public class IsotonicRegressionModelMapper extends ModelMapper {
	/** Index, within the data schema, of the column the feature is read from. */
	private int colIdx;
	/** Model data (boundaries, values and meta) produced by {@link IsotonicRegressionConverter}. */
	private IsotonicRegressionModelData modelData;
	/** Name of the vector column stored in the model meta; {@code null} selects the plain feature column. */
	private String vectorColName;
	/** Position of the feature inside the vector; only read when {@link #vectorColName} is not null. */
	private int featureIndex;

	/**
	 * Constructor.
	 *
	 * @param modelSchema the model schema.
	 * @param dataSchema  the data schema.
	 * @param params      the params.
	 */
	public IsotonicRegressionModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
		super(modelSchema, dataSchema, params);
	}

	/**
	 * Load model from the list of Row type data and resolve which input column holds the feature.
	 *
	 * @param modelRows the list of Row type data.
	 */
	@Override
	public void loadModel(List<Row> modelRows) {
		IsotonicRegressionConverter converter = new IsotonicRegressionConverter();
		modelData = converter.load(modelRows);
		Params meta = modelData.meta;
		String featureColName = meta.get(IsotonicRegTrainParams.FEATURE_COL);
		vectorColName = meta.get(IsotonicRegTrainParams.VECTOR_COL);
		featureIndex = meta.get(IsotonicRegTrainParams.FEATURE_INDEX);
		// The vector column takes precedence; the plain feature column is only used when no
		// vector column was recorded in the model meta.
		String sourceColName = (null == vectorColName) ? featureColName : vectorColName;
		colIdx = TableUtil.findColIndex(getDataSchema().getFieldNames(), sourceColName);
	}

	/**
	 * Get the table schema(includes column names and types) of the calculation result: all input
	 * columns followed by the prediction column (type DOUBLE).
	 *
	 * @return the table schema of output Row type data.
	 */
	@Override
	public TableSchema getOutputSchema() {
		String predictResultColName = this.params.get(HasPredictionCol.PREDICTION_COL);
		TableSchema dataSchema = getDataSchema();
		return new TableSchema(
			ArrayUtils.add(dataSchema.getFieldNames(), predictResultColName),
			ArrayUtils.add(dataSchema.getFieldTypes(), Types.DOUBLE())
		);
	}

	/**
	 * Map operation method.
	 *
	 * <p>A null row yields null. A row whose feature field is null is merged with a null
	 * prediction. Otherwise the feature is looked up in the model boundaries with a binary search.
	 *
	 * @param row the input Row type data.
	 * @return one Row type data.
	 * @throws Exception This method may throw exceptions. Throwing
	 *                   an exception will cause the operation to fail.
	 */
	@Override
	public Row map(Row row) throws Exception {
		if (null == row) {
			return null;
		}
		Object rawFeature = row.getField(colIdx);
		if (null == rawFeature) {
			// No feature to predict from: merge a null prediction into the row.
			return RowUtil.merge(row, (Object) null);
		}
		double feature;
		if (null == this.vectorColName) {
			feature = ((Number) rawFeature).doubleValue();
		} else {
			feature = VectorUtil.getVector(rawFeature).get(this.featureIndex);
		}
		//use Binary search method to search for the boundaries.
		int foundIndex = Arrays.binarySearch(modelData.boundaries, feature);
		double predict;
		if (foundIndex >= 0) {
			// Exact hit on a boundary: use the value stored for that boundary.
			predict = modelData.values[foundIndex];
		} else {
			// Arrays.binarySearch returns (-(insertion point) - 1) when the key is absent.
			int insertIndex = -foundIndex - 1;
			if (insertIndex == 0) {
				// Below the first boundary: clamp to the first value.
				predict = modelData.values[0];
			} else if (insertIndex == modelData.boundaries.length) {
				// Above the last boundary: clamp to the last value.
				predict = modelData.values[modelData.values.length - 1];
			} else {
				// Strictly between two neighbouring boundaries: combine their values.
				// NOTE(review): presumably minMaxScaler performs a linear interpolation here -
				// confirm against ScalerUtil.
				predict = ScalerUtil.minMaxScaler(feature,
					modelData.boundaries[insertIndex - 1], modelData.boundaries[insertIndex],
					modelData.values[insertIndex], modelData.values[insertIndex - 1]);
			}
		}
		return RowUtil.merge(row, predict);
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy