All downloads are free. The search and download features use the official Maven repository.

org.apache.ignite.ml.selection.cv.CrossValidation Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.selection.cv;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.lang.IgniteBiPredicate;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
import org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator;
import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursor;
import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursor;
import org.apache.ignite.ml.selection.scoring.metric.Metric;
import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
import org.apache.ignite.ml.selection.split.mapper.UniformMapper;
import org.apache.ignite.ml.trainers.DatasetTrainer;

/**
 * Cross validation score calculator. Cross validation is an approach that allows to avoid overfitting that is made the
 * following way: the training set is split into k smaller sets. The following procedure is followed for each of the k
 * “folds”:
 * 
    *
  • A model is trained using k-1 of the folds as training data;
  • *
  • the resulting model is validated on the remaining part of the data (i.e., it is used as a test set to compute * a performance measure such as accuracy).
  • *
* * @param Type of model. * @param Type of a label (truth or prediction). * @param Type of a key in {@code upstream} data. * @param Type of a value in {@code upstream} data. */ public class CrossValidation, L, K, V> { /** * Computes cross-validated metrics. * * @param trainer Trainer of the model. * @param scoreCalculator Score calculator. * @param ignite Ignite instance. * @param upstreamCache Ignite cache with {@code upstream} data. * @param featureExtractor Feature extractor. * @param lbExtractor Label extractor. * @param cv Number of folds. * @return Array of scores of the estimator for each run of the cross validation. */ public double[] score(DatasetTrainer trainer, Metric scoreCalculator, Ignite ignite, IgniteCache upstreamCache, IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor, int cv) { return score(trainer, scoreCalculator, ignite, upstreamCache, (k, v) -> true, featureExtractor, lbExtractor, new SHA256UniformMapper<>(), cv); } /** * Computes cross-validated metrics. * * @param trainer Trainer of the model. * @param scoreCalculator Base score calculator. * @param ignite Ignite instance. * @param upstreamCache Ignite cache with {@code upstream} data. * @param filter Base {@code upstream} data filter. * @param featureExtractor Feature extractor. * @param lbExtractor Label extractor. * @param cv Number of folds. * @return Array of scores of the estimator for each run of the cross validation. */ public double[] score(DatasetTrainer trainer, Metric scoreCalculator, Ignite ignite, IgniteCache upstreamCache, IgniteBiPredicate filter, IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor, int cv) { return score(trainer, scoreCalculator, ignite, upstreamCache, filter, featureExtractor, lbExtractor, new SHA256UniformMapper<>(), cv); } /** * Computes cross-validated metrics with a passed parameter grid. * * The real cross-validation training will be called each time for each parameter set. * * @param trainer Trainer of the model. 
* @param scoreCalculator Base score calculator. * @param ignite Ignite instance. * @param upstreamCache Ignite cache with {@code upstream} data. * @param filter Base {@code upstream} data filter. * @param featureExtractor Feature extractor. * @param lbExtractor Label extractor. * @param cv Number of folds. * @param paramGrid Parameter grid. * @return Array of scores of the estimator for each run of the cross validation. */ public CrossValidationResult score(DatasetTrainer trainer, Metric scoreCalculator, Ignite ignite, IgniteCache upstreamCache, IgniteBiPredicate filter, IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor, int cv, ParamGrid paramGrid) { List paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate(); CrossValidationResult cvRes = new CrossValidationResult(); paramSets.forEach(paramSet -> { Map paramMap = new HashMap<>(); for (int paramIdx = 0; paramIdx < paramSet.length; paramIdx++) { String paramName = paramGrid.getParamNameByIndex(paramIdx); Double paramVal = paramSet[paramIdx]; paramMap.put(paramName, paramVal); try { final String mtdName = "with" + paramName.substring(0, 1).toUpperCase() + paramName.substring(1); Method trainerSetter = null; // We should iterate along all methods due to we have no info about signature and passed types. 
for (Method method : trainer.getClass().getDeclaredMethods()) { if (method.getName().equals(mtdName)) trainerSetter = method; } if (trainerSetter != null) trainerSetter.invoke(trainer, paramVal); else throw new NoSuchMethodException(mtdName); } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { e.printStackTrace(); } } double[] locScores = score(trainer, scoreCalculator, ignite, upstreamCache, filter, featureExtractor, lbExtractor, new SHA256UniformMapper<>(), cv); cvRes.addScores(locScores, paramMap); final double locAvgScore = Arrays.stream(locScores).average().orElse(Double.MIN_VALUE); if (locAvgScore > cvRes.getBestAvgScore()) { cvRes.setBestScore(locScores); cvRes.setBestHyperParams(paramMap); System.out.println(paramMap.toString()); } }); return cvRes; } /** * Computes cross-validated metrics. * * @param trainer Trainer of the model. * @param scoreCalculator Base score calculator. * @param ignite Ignite instance. * @param upstreamCache Ignite cache with {@code upstream} data. * @param filter Base {@code upstream} data filter. * @param featureExtractor Feature extractor. * @param lbExtractor Label extractor. * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1). * @param cv Number of folds. * @return Array of scores of the estimator for each run of the cross validation. 
*/ public double[] score(DatasetTrainer trainer, Metric scoreCalculator, Ignite ignite, IgniteCache upstreamCache, IgniteBiPredicate filter, IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor, UniformMapper mapper, int cv) { return score( trainer, predicate -> new CacheBasedDatasetBuilder<>( ignite, upstreamCache, (k, v) -> filter.apply(k, v) && predicate.apply(k, v) ), (predicate, mdl) -> new CacheBasedLabelPairCursor<>( upstreamCache, (k, v) -> filter.apply(k, v) && !predicate.apply(k, v), featureExtractor, lbExtractor, mdl ), featureExtractor, lbExtractor, scoreCalculator, mapper, cv ); } /** * Computes cross-validated metrics. * * @param trainer Trainer of the model. * @param scoreCalculator Base score calculator. * @param upstreamMap Map with {@code upstream} data. * @param parts Number of partitions. * @param featureExtractor Feature extractor. * @param lbExtractor Label extractor. * @param cv Number of folds. * @return Array of scores of the estimator for each run of the cross validation. */ public double[] score(DatasetTrainer trainer, Metric scoreCalculator, Map upstreamMap, int parts, IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor, int cv) { return score(trainer, scoreCalculator, upstreamMap, (k, v) -> true, parts, featureExtractor, lbExtractor, new SHA256UniformMapper<>(), cv); } /** * Computes cross-validated metrics. * * @param trainer Trainer of the model. * @param scoreCalculator Base score calculator. * @param upstreamMap Map with {@code upstream} data. * @param filter Base {@code upstream} data filter. * @param parts Number of partitions. * @param featureExtractor Feature extractor. * @param lbExtractor Label extractor. * @param cv Number of folds. * @return Array of scores of the estimator for each run of the cross validation. 
*/ public double[] score(DatasetTrainer trainer, Metric scoreCalculator, Map upstreamMap, IgniteBiPredicate filter, int parts, IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor, int cv) { return score(trainer, scoreCalculator, upstreamMap, filter, parts, featureExtractor, lbExtractor, new SHA256UniformMapper<>(), cv); } /** * Computes cross-validated metrics. * * @param trainer Trainer of the model. * @param scoreCalculator Base score calculator. * @param upstreamMap Map with {@code upstream} data. * @param filter Base {@code upstream} data filter. * @param parts Number of partitions. * @param featureExtractor Feature extractor. * @param lbExtractor Label extractor. * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1). * @param cv Number of folds. * @return Array of scores of the estimator for each run of the cross validation. */ public double[] score(DatasetTrainer trainer, Metric scoreCalculator, Map upstreamMap, IgniteBiPredicate filter, int parts, IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor, UniformMapper mapper, int cv) { return score( trainer, predicate -> new LocalDatasetBuilder<>( upstreamMap, (k, v) -> filter.apply(k, v) && predicate.apply(k, v), parts ), (predicate, mdl) -> new LocalLabelPairCursor<>( upstreamMap, (k, v) -> filter.apply(k, v) && !predicate.apply(k, v), featureExtractor, lbExtractor, mdl ), featureExtractor, lbExtractor, scoreCalculator, mapper, cv ); } /** * Computes cross-validated metrics. * * @param trainer Trainer of the model. * @param datasetBuilderSupplier Dataset builder supplier. * @param testDataIterSupplier Test data iterator supplier. * @param featureExtractor Feature extractor. * @param lbExtractor Label extractor. * @param scoreCalculator Base score calculator. * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1). * @param cv Number of folds. * @return Array of scores of the estimator for each run of the cross validation. 
*/ private double[] score(DatasetTrainer trainer, Function, DatasetBuilder> datasetBuilderSupplier, BiFunction, M, LabelPairCursor> testDataIterSupplier, IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor, Metric scoreCalculator, UniformMapper mapper, int cv) { double[] scores = new double[cv]; double foldSize = 1.0 / cv; for (int i = 0; i < cv; i++) { double from = foldSize * i; double to = foldSize * (i + 1); IgniteBiPredicate trainSetFilter = (k, v) -> { double pnt = mapper.map(k, v); return pnt < from || pnt > to; }; DatasetBuilder datasetBuilder = datasetBuilderSupplier.apply(trainSetFilter); M mdl = trainer.fit(datasetBuilder, featureExtractor, lbExtractor); try (LabelPairCursor cursor = testDataIterSupplier.apply(trainSetFilter, mdl)) { scores[i] = scoreCalculator.score(cursor.iterator()); } catch (Exception e) { throw new RuntimeException(e); } } return scores; } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy