All downloads are free. The search and download functionality uses the official Maven repository.

com.alibaba.alink.operator.common.dataproc.MultiStringIndexerModelMapper Maven / Gradle / Ivy

package com.alibaba.alink.operator.common.dataproc;


import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.utils.OutputColsHelper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.dataproc.MultiStringIndexerPredictParams;
import com.alibaba.alink.params.shared.colname.HasSelectedCols;
import com.alibaba.alink.pipeline.dataproc.MultiStringIndexerModel;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * The model mapper for {@link MultiStringIndexerModel}.
 */
public class MultiStringIndexerModelMapper extends ModelMapper {

    /**
     * A list maps that map token to index. Each map in the list corresponds to a column.
     */
    private transient Map> indexMapper;

    /**
     * The default index for unseen tokens of each columns.
     */
    private transient Map defaultIndex;

    private OutputColsHelper outputColsHelper;
    private String[] selectedColNames;
    private int[] selectedColIndicesInData;
    private StringIndexerUtil.HandleInvalidStrategy handleInvalidStrategy;

    public MultiStringIndexerModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
        selectedColNames = params.get(MultiStringIndexerPredictParams.SELECTED_COLS);
        String[] outputColNames = params.get(MultiStringIndexerPredictParams.OUTPUT_COLS);
        if (outputColNames == null) {
            outputColNames = selectedColNames;
        }
        String[] reservedColNames = params.get(MultiStringIndexerPredictParams.RESERVED_COLS);

        handleInvalidStrategy = StringIndexerUtil.HandleInvalidStrategy
            .valueOf(params.get(MultiStringIndexerPredictParams.HANDLE_INVALID).toUpperCase());

        TypeInformation[] outputColTypes = new TypeInformation[selectedColNames.length];
        Arrays.fill(outputColTypes, Types.LONG);

        outputColsHelper = new OutputColsHelper(dataSchema, outputColNames, outputColTypes, reservedColNames);
    }

    @Override
    public void loadModel(List modelRows) {
        MultiStringIndexerModelData model = new MultiStringIndexerModelDataConverter().load(modelRows);

        String[] trainColNames = model.meta.get(HasSelectedCols.SELECTED_COLS);
        this.selectedColIndicesInData = TableUtil.findColIndices(super.getDataSchema(), selectedColNames);
        int[] selectedColIndicesInModel = TableUtil.findColIndices(trainColNames, selectedColNames);
        this.indexMapper = new HashMap<>();
        this.defaultIndex = new HashMap<>();

        for (int i = 0; i < selectedColNames.length; i++) {
            Map mapper = new HashMap<>();
            int colIdxInModel = selectedColIndicesInModel[i];
            for (Tuple3 record : model.tokenAndIndex) {
                if (record.f0 == colIdxInModel) {
                    String token = record.f1;
                    Long index = record.f2;
                    mapper.put(token, index);
                }
            }
            indexMapper.put(i, mapper);
            Long defaultIdx = model.tokenNumber.get(colIdxInModel);
            defaultIndex.put(i, defaultIdx);
        }
    }

    @Override
    public TableSchema getOutputSchema() {
        return outputColsHelper.getResultSchema();
    }

    @Override
    public Row map(Row row) throws Exception {
        Row result = new Row(selectedColNames.length);
        for (int i = 0; i < selectedColNames.length; i++) {
            Map mapper = indexMapper.get(i);
            int colIdxInData = selectedColIndicesInData[i];
            Object val = row.getField(colIdxInData);
            String key = val == null ? null : String.valueOf(val);
            Long index = mapper.get(key);
            if (index != null) {
                result.setField(i, index);
            } else {
                switch (this.handleInvalidStrategy) {
                    case KEEP:
                        Long localDefaultIndex = defaultIndex.get(i);
                        result.setField(i, localDefaultIndex);
                        break;
                    case SKIP:
                        result.setField(i, null);
                        break;
                    case ERROR:
                        throw new RuntimeException("Unseen token: " + key);
                    default:
                        throw new IllegalArgumentException("Invalid handle invalid strategy.");
                }
            }
        }
        return outputColsHelper.getResultRow(row, result);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy