// All downloads are free. Search and download functionality uses the official Maven repository.

// com.alibaba.alink.operator.batch.feature.OneHotTrainBatchOp Maven / Gradle / Ivy

package com.alibaba.alink.operator.batch.feature;

import java.util.*;

import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.feature.OneHotModelData;
import com.alibaba.alink.operator.common.feature.OneHotModelDataConverter;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.feature.OneHotTrainParams;

import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * One-hot maps a serial of columns of category indices to a column of
 * sparse binary vector. It will produce a model of one hot, and then it can transform
 * data to binary format using this model.
 */
public final class OneHotTrainBatchOp extends BatchOperator
	implements OneHotTrainParams  {

	public final static String DELIMITER = "@ # %";
	private final static String NULL_VALUE = "null";

	/**
	 * null constructor.
	 */
	public OneHotTrainBatchOp() {
		super(null);
	}

	/**
	 * constructor.
	 *
	 * @param params the parameters set.
	 */
	public OneHotTrainBatchOp(Params params) {
		super(params);
	}

	@Override
	public OneHotTrainBatchOp linkFrom(BatchOperator... inputs) {
		BatchOperator in = checkAndGetFirst(inputs);
		// encoding columns names
		String[] selectedColNames = getSelectedCols();

		// the type to processing the NULL values
		boolean ignoreNull = getIgnoreNull();

		// drop the last coding value for nonlinear of kv vectors
		boolean dropLast = getDropLast();

		String[] colNames = in.getColNames();
		int[] idx = new int[selectedColNames.length];
		for (int i = 0; i < selectedColNames.length; ++i) {
			idx[i] = TableUtil.findColIndex(colNames, selectedColNames[i]);
		}

		DataSet  mapping = in.getDataSet()
			.mapPartition(new ParseItem(selectedColNames, idx))
			.reduceGroup(new ReduceItem(dropLast, ignoreNull, selectedColNames))
			.setParallelism(1);

		this.setOutput(mapping, new OneHotModelDataConverter().getModelSchema());
		return this;
	}

	/**
	 * reduce all the items and then build model.
	 */
	public static class ReduceItem implements GroupReduceFunction {
		private boolean dropLast;
		private boolean ignoreNull;
		private String[] selectedColNames;

		public ReduceItem(boolean dropLast, boolean ignoreNull, String[] selectedColNames) {
			this.dropLast = dropLast;
			this.ignoreNull = ignoreNull;
			this.selectedColNames = selectedColNames;
		}

		@Override
		public void reduce(Iterable rows, Collector collector) throws Exception {
			Map > map = new LinkedHashMap <>(0);
			for (Row row : rows) {
				String colName = row.getField(0).toString();
				String value = row.getField(1).toString();
				if (map.containsKey(colName)) {
					map.get(colName).add(value);
				} else {
					HashSet  set = new HashSet <>();
					set.add(value);
					map.put(colName, set);
				}
			}

			// data for model.
			// construct model data.
			ArrayList  data = new ArrayList <>();
			int mapIteration = 0;
			if (ignoreNull) {
				for (String selectedColName : selectedColNames) {
					HashSet cate = map.get(selectedColName);
					for (String val : cate) {
						if (val.equalsIgnoreCase(NULL_VALUE)) {
							continue;
						}
						String tmp = selectedColName + DELIMITER
								+ val + DELIMITER
								+ mapIteration++;
						data.add(tmp);
					}
				}
				if (dropLast) {
					if (!(data.get(data.size() - 1).equalsIgnoreCase(NULL_VALUE))) {
						data.remove(data.size() - 1);
						mapIteration--;
					}
				}
			} else {
				for (String selectedColName : selectedColNames) {
					HashSet cate = map.get(selectedColName);
					for (String val : cate) {
						String tmp = selectedColName + DELIMITER
								+ val + DELIMITER
								+ mapIteration++;
						data.add(tmp);
					}
				}
				if (dropLast) {
					data.remove(data.size() - 1);
					mapIteration--;
				}
			}

			// save model.
			OneHotModelData model = new OneHotModelData();
			model.data = data;
			model.meta.set(ModelParamName.VECTOR_SIZE, mapIteration + 1);

			new OneHotModelDataConverter().save(model, collector);
		}
	}

	/**
	 * Count all the possible values of each column and then keep them.
	 */
	public static class ParseItem implements MapPartitionFunction  {
		private String[] binaryColNames;
		private int[] idx;

		ParseItem(String[] binaryColNames, int[] idx) {
			this.binaryColNames = binaryColNames;
			this.idx = idx;
		}

		@Override
		public void mapPartition(Iterable  rows, Collector  collector) throws Exception {
			Map > map = new HashMap<>(0);
			int m = this.binaryColNames.length;

			for (Row row : rows) {
				for (int i = 0; i < m; i++) {
					String colName = this.binaryColNames[i];
					Object obj = row.getField(idx[i]);
					String value = (obj == null) ? "null" : obj.toString();
					if (map.containsKey(colName)) {
						HashSet  set = map.get(colName);
						set.add(value);
					} else {
						HashSet  set = new HashSet <>();
						set.add(value);
						map.put(colName, set);
					}
				}
			}
			//for each value of one column, keep it and its column value in a row.
			for (Map.Entry > entry : map.entrySet()) {
				String name = entry.getKey();
				HashSet  values = entry.getValue();
				for (String value : values) {
					Row r = new Row(2);
					r.setField(0, name);
					r.setField(1, value);
					collector.collect(r);
				}
			}
		}
	}
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy