
org.kie.pmml.models.regression.model.KiePMMLRegressionClassificationTable Maven / Gradle / Ivy
/*
* Copyright 2020 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.kie.pmml.models.regression.model;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.DoubleUnaryOperator;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.kie.pmml.api.enums.OP_TYPE;
import org.kie.pmml.models.regression.model.enums.REGRESSION_NORMALIZATION_METHOD;
import static org.kie.pmml.commons.Constants.EXPECTED_TWO_ENTRIES_RETRIEVED;
/**
 * Base class for regression tables used for classification.
 * <p>
 * Each category has its own nested {@link KiePMMLRegressionTable}. Evaluation computes every
 * category's raw score, converts the scores through {@link #getProbabilityMap(LinkedHashMap)}
 * and returns the category whose resulting value is the highest.
 */
public abstract class KiePMMLRegressionClassificationTable extends KiePMMLRegressionTable {

    /** Normalization method declared for this table; exposed via {@link #getRegressionNormalizationMethod()}. */
    protected REGRESSION_NORMALIZATION_METHOD regressionNormalizationMethod;
    /** Operation type of this table; exposed via {@link #getOpType()}. */
    protected OP_TYPE opType;
    /**
     * One regression table per category, keyed by category name.
     * <p>
     * Insertion order matters: the score and probability maps are built in this order, and
     * {@link #getProbabilityMap(LinkedHashMap, DoubleUnaryOperator, DoubleUnaryOperator)} treats
     * the first entry differently from the second.
     */
    protected Map<String, KiePMMLRegressionTable> categoryTableMap = new LinkedHashMap<>();

    /**
     * Evaluates every category table against {@code input} and returns the predicted category.
     * <p>
     * Raw scores are collected in {@link #categoryTableMap} order and converted by
     * {@link #getProbabilityMap(LinkedHashMap)}. The entry with the highest value is selected.
     * Its value is also stored in the probability map under {@code targetField}. Both the
     * predicted category and the (augmented) probability map are then handed to the
     * output-field population hooks.
     *
     * @param input the input values, passed unchanged to each category table
     * @return the key of the winning category
     * @throws KiePMMLException if the probability map is empty, so no category can be predicted
     */
    @Override
    public Object evaluateRegression(Map<String, Object> input) {
        final LinkedHashMap<String, Double> resultMap = new LinkedHashMap<>();
        for (Map.Entry<String, KiePMMLRegressionTable> entry : categoryTableMap.entrySet()) {
            // each category table is expected to produce a Double score
            resultMap.put(entry.getKey(), (Double) entry.getValue().evaluateRegression(input));
        }
        final LinkedHashMap<String, Double> probabilityMap = getProbabilityMap(resultMap);
        if (probabilityMap.isEmpty()) {
            // Collections.max would otherwise fail with an uninformative NoSuchElementException
            throw new KiePMMLException("No category probabilities available to predict target field "
                                               + targetField);
        }
        final Map.Entry<String, Double> predictedEntry =
                Collections.max(probabilityMap.entrySet(), Map.Entry.comparingByValue());
        final String predictedCategory = predictedEntry.getKey();
        probabilityMap.put(targetField, predictedEntry.getValue());
        populateOutputFieldsMapWithResult(predictedCategory);
        populateOutputFieldsMapWithProbability(predictedEntry, probabilityMap);
        return predictedCategory;
    }

    /**
     * A Classification is considered binary if it is of CATEGORICAL type and contains exactly two
     * Regression tables.
     *
     * @return {@code true} if this table describes a binary classification, {@code false} otherwise
     */
    public abstract boolean isBinary();

    /**
     * Converts the raw per-category scores into the values used to pick the predicted category
     * (presumably probabilities; the normalization is implementation-specific).
     * <p>
     * The returned map is mutated afterwards by {@link #evaluateRegression(Map)}, which adds an
     * entry for {@code targetField}. Implementations must therefore return a mutable map.
     *
     * @param resultMap raw scores keyed by category, in {@link #categoryTableMap} order
     * @return a mutable map of category to value
     */
    protected abstract LinkedHashMap<String, Double> getProbabilityMap(final LinkedHashMap<String, Double> resultMap);

    /**
     * Populates output fields from the winning entry and the probability map.
     *
     * @param predictedEntry the entry with the highest value in {@code probabilityMap}
     * @param probabilityMap the category values; at call time it also contains an entry keyed by
     *                       {@code targetField} holding the winning value
     */
    protected abstract void populateOutputFieldsMapWithProbability(final Map.Entry<String, Double> predictedEntry,
                                                                   final LinkedHashMap<String, Double> probabilityMap);

    /**
     * No-op: classification tables do not update the result here.
     *
     * @param toUpdate ignored
     */
    protected void updateResult(final AtomicReference<Double> toUpdate) {
        // NOOP
    }

    /**
     * @return the normalization method of this table
     */
    public REGRESSION_NORMALIZATION_METHOD getRegressionNormalizationMethod() {
        return regressionNormalizationMethod;
    }

    /**
     * @return the operation type of this table
     */
    public OP_TYPE getOpType() {
        return opType;
    }

    /**
     * Returns the live category-to-table map. This is not a copy: changes made through the
     * returned map affect this table.
     *
     * @return the category table map
     */
    public Map<String, KiePMMLRegressionTable> getCategoryTableMap() {
        return categoryTableMap;
    }

    /**
     * Builds the two-entry value map for a binary classification.
     * <p>
     * The first category's value is {@code firstItemOperator} applied to the first raw score. The
     * second category's value is {@code secondItemOperator} applied to that first value, so the
     * second raw score is not read.
     *
     * @param resultMap          raw scores keyed by category; must contain exactly two entries
     * @param firstItemOperator  maps the first raw score to the first value
     * @param secondItemOperator derives the second value from the first value
     * @return a new mutable map with both categories, in {@code resultMap} order
     * @throws KiePMMLException if {@code resultMap} does not contain exactly two entries
     */
    protected LinkedHashMap<String, Double> getProbabilityMap(final LinkedHashMap<String, Double> resultMap,
                                                              DoubleUnaryOperator firstItemOperator,
                                                              DoubleUnaryOperator secondItemOperator) {
        if (resultMap.size() != 2) {
            throw new KiePMMLException(String.format(EXPECTED_TWO_ENTRIES_RETRIEVED, resultMap.size()));
        }
        final String[] resultMapKeys = resultMap.keySet().toArray(new String[0]);
        final double firstItem = firstItemOperator.applyAsDouble(resultMap.get(resultMapKeys[0]));
        // the second value is derived from the first one, not from the second raw score
        final double secondItem = secondItemOperator.applyAsDouble(firstItem);
        final LinkedHashMap<String, Double> toReturn = new LinkedHashMap<>();
        toReturn.put(resultMapKeys[0], firstItem);
        toReturn.put(resultMapKeys[1], secondItem);
        return toReturn;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy