All downloads are free. Search and download functionality uses the official Maven repository.

smile.hpo.Hyperparameters Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
 *
 * Smile is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Smile is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Smile.  If not, see <https://www.gnu.org/licenses/>.
 */

package smile.hpo;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.stream.Stream;
import smile.math.MathEx;

/**
 * Hyperparameter configuration. A hyperparameter is a parameter whose value is set
 * before the learning process begins. By contrast, the values of other
 * parameters are derived via training.
 * <p>
 * Hyperparameters can be classified as model hyperparameters, which cannot
 * be inferred while fitting the machine to the training set because they
 * refer to the model selection task, or algorithm hyperparameters, which
 * in principle have no influence on the performance of the model but
 * affect the speed and quality of the learning process. For example,
 * the topology and size of a neural network are model hyperparameters,
 * while learning rate and mini-batch size are algorithm hyperparameters.
 * <p>
 * Each parameter is registered with one of the {@code add} methods as a
 * fixed value, an array of candidate values, or a numeric range. Registering
 * a name again replaces its earlier definition. {@link #grid()} enumerates
 * every combination of the candidates (grid search), while {@link #random()}
 * draws an endless stream of random combinations (random search).
 * <p>
 * The example below shows how to tune the hyperparameters of random forest.
 * <pre>{@code
 *    import smile.io.*;
 *    import smile.data.formula.Formula;
 *    import smile.validation.*;
 *    import smile.classification.RandomForest;
 *
 *    var hp = new Hyperparameters()
 *        .add("smile.random.forest.trees", 100) // a fixed value
 *        .add("smile.random.forest.mtry", new int[] {2, 3, 4}) // an array of values to choose
 *        .add("smile.random.forest.max.nodes", 100, 500, 50); // range [100, 500] with step 50
 *
 *    var train = Read.arff("data/weka/segment-challenge.arff");
 *    var test = Read.arff("data/weka/segment-test.arff");
 *    var formula = Formula.lhs("class");
 *    var testy = formula.y(test).toIntArray();
 *
 *    hp.grid().forEach(prop -> {
 *        var model = RandomForest.fit(formula, train, prop);
 *        var pred = model.predict(test);
 *        System.out.println(prop);
 *        System.out.format("Accuracy = %.2f%%%n", (100.0 * Accuracy.of(testy, pred)));
 *        System.out.println(ConfusionMatrix.of(testy, pred));
 *    });
 * }</pre>
 *
 * @author Haifeng Li
 */
public class Hyperparameters {
    /**
     * The registered parameters, kept in insertion order so that grid search
     * enumerates them predictably. Each value is an {@code int[]},
     * {@code double[]}, {@code String[]}, {@link IntRange} or {@link DoubleRange}.
     */
    private final Map<String, Object> parameters = new LinkedHashMap<>();

    /** A single parameter assignment, with the value already formatted as a string. */
    record KeyValue(String key, String value) { }

    /**
     * An integer range [start, end] enumerated with the given step.
     *
     * @param start the first value (inclusive).
     * @param end the upper bound (inclusive).
     * @param step the distance between consecutive grid values.
     */
    record IntRange(int start, int end, int step) {
        /**
         * Validates the range.
         *
         * @throws IllegalArgumentException if {@code start >= end} or {@code step <= 0}.
         */
        public IntRange {
            if (start >= end) {
                throw new IllegalArgumentException(String.format("start = %d, end = %d", start, end));
            }
            if (step <= 0) {
                throw new IllegalArgumentException("step = " + step);
            }
        }

        /**
         * Creates a range whose step splits it into about ten intervals.
         *
         * @param start the first value (inclusive).
         * @param end the upper bound (inclusive).
         */
        public IntRange(int start, int end) {
            this(start, end, Math.max(1, (end - start) / 10));
        }

        /** Returns start, start + step, ... up to the largest value not exceeding end. */
        int[] toArray() {
            int n = (end - start) / step + 1;
            int[] a = new int[n];
            for (int i = 0; i < n; i++) {
                a[i] = start + i * step;
            }
            return a;
        }
    }

    /**
     * A floating-point range [start, end] enumerated with the given step.
     *
     * @param start the first value (inclusive).
     * @param end the upper bound (inclusive).
     * @param step the distance between consecutive grid values.
     */
    record DoubleRange(double start, double end, double step) {
        /**
         * Validates the range.
         *
         * @throws IllegalArgumentException if the bounds are not finite,
         *         {@code start >= end}, or {@code step} is not positive.
         */
        public DoubleRange {
            // !(start < end) also rejects NaN bounds; infinite bounds would make
            // toArray() compute a nonsensical array size.
            if (!(start < end) || Double.isInfinite(start) || Double.isInfinite(end)) {
                throw new IllegalArgumentException(String.format("start = %f, end = %f", start, end));
            }
            // !(step > 0.0) also rejects a NaN step.
            if (!(step > 0.0)) {
                throw new IllegalArgumentException("step = " + step);
            }
        }

        /**
         * Creates a range split into ten intervals.
         *
         * @param start the first value (inclusive).
         * @param end the upper bound (inclusive).
         */
        public DoubleRange(double start, double end) {
            this(start, end, (end - start) / 10);
        }

        /** Returns start, start + step, ... up to end (inclusive when reachable). */
        double[] toArray() {
            // The tiny tolerance keeps the end point when rounding makes the
            // number of intervals land just below a whole number (e.g. 0.3 / 0.1).
            double intervals = (end - start) / step;
            int n = (int) Math.floor(intervals + 1E-9) + 1;
            double[] a = new double[n];
            for (int i = 0; i < n; i++) {
                // Multiply rather than accumulate so rounding errors do not
                // compound, and never step past the end of the range.
                a[i] = Math.min(start + i * step, end);
            }
            return a;
        }
    }

    /** Constructor. */
    public Hyperparameters() {

    }

    /**
     * Adds a parameter with a fixed integer value.
     *
     * @param name the parameter name.
     * @param value a fixed value of parameter.
     * @return this object.
     */
    public Hyperparameters add(String name, int value) {
        return add(name, new int[] {value});
    }

    /**
     * Adds a parameter with a fixed floating-point value.
     *
     * @param name the parameter name.
     * @param value a fixed value of parameter.
     * @return this object.
     */
    public Hyperparameters add(String name, double value) {
        return add(name, new double[] {value});
    }

    /**
     * Adds a parameter with a fixed string value.
     *
     * @param name the parameter name.
     * @param value a fixed value of parameter.
     * @return this object.
     */
    public Hyperparameters add(String name, String value) {
        return add(name, new String[] {value});
    }

    /**
     * Adds a parameter with candidate integer values.
     *
     * @param name the parameter name.
     * @param values an array of parameter values. The array is copied, so later
     *               changes to it do not affect this configuration.
     * @return this object.
     * @throws IllegalArgumentException if {@code values} is empty.
     */
    public Hyperparameters add(String name, int[] values) {
        if (values.length == 0) {
            throw new IllegalArgumentException("Empty array");
        }
        parameters.put(name, values.clone());
        return this;
    }

    /**
     * Adds a parameter with candidate floating-point values.
     *
     * @param name the parameter name.
     * @param values an array of parameter values. The array is copied, so later
     *               changes to it do not affect this configuration.
     * @return this object.
     * @throws IllegalArgumentException if {@code values} is empty.
     */
    public Hyperparameters add(String name, double[] values) {
        if (values.length == 0) {
            throw new IllegalArgumentException("Empty array");
        }
        parameters.put(name, values.clone());
        return this;
    }

    /**
     * Adds a parameter with candidate string values.
     *
     * @param name the parameter name.
     * @param values an array of parameter values. The array is copied, so later
     *               changes to it do not affect this configuration.
     * @return this object.
     * @throws IllegalArgumentException if {@code values} is empty.
     */
    public Hyperparameters add(String name, String[] values) {
        if (values.length == 0) {
            throw new IllegalArgumentException("Empty array");
        }
        parameters.put(name, values.clone());
        return this;
    }

    /**
     * Adds an integer parameter ranging over [start, end]. Grid search uses
     * the step {@code max(1, (end - start) / 10)}.
     *
     * @param name the parameter name.
     * @param start the start of value range (inclusive).
     * @param end the end of value range (inclusive).
     * @return this object.
     * @throws IllegalArgumentException if {@code start >= end}.
     */
    public Hyperparameters add(String name, int start, int end) {
        parameters.put(name, new IntRange(start, end));
        return this;
    }

    /**
     * Adds an integer parameter ranging over [start, end].
     *
     * @param name the parameter name.
     * @param start the start of value range (inclusive).
     * @param end the end of value range (inclusive).
     * @param step the step size used by grid search.
     * @return this object.
     * @throws IllegalArgumentException if {@code start >= end} or {@code step <= 0}.
     */
    public Hyperparameters add(String name, int start, int end, int step) {
        parameters.put(name, new IntRange(start, end, step));
        return this;
    }

    /**
     * Adds a floating-point parameter ranging over [start, end]. Grid search
     * uses the step {@code (end - start) / 10}.
     *
     * @param name the parameter name.
     * @param start the start of value range (inclusive).
     * @param end the end of value range (inclusive).
     * @return this object.
     * @throws IllegalArgumentException if the bounds are not finite or {@code start >= end}.
     */
    public Hyperparameters add(String name, double start, double end) {
        parameters.put(name, new DoubleRange(start, end));
        return this;
    }

    /**
     * Adds a floating-point parameter ranging over [start, end].
     *
     * @param name the parameter name.
     * @param start the start of value range (inclusive).
     * @param end the end of value range (inclusive).
     * @param step the step size used by grid search.
     * @return this object.
     * @throws IllegalArgumentException if the bounds are not finite,
     *         {@code start >= end}, or {@code step} is not positive.
     */
    public Hyperparameters add(String name, double start, double end, double step) {
        parameters.put(name, new DoubleRange(start, end, step));
        return this;
    }

    /**
     * Generates an infinite stream of hyperparameters for random search; use
     * {@link Stream#limit(long)} to bound it. In each generated
     * {@link Properties}, an array parameter takes one of its values chosen at
     * random, an integer range takes a random integer in [start, end] (the
     * step is ignored), and a double range takes a random value between start
     * and end. Random numbers are drawn from {@link MathEx}.
     *
     * @return the stream of hyperparameters for random search.
     */
    public Stream<Properties> random() {
        return Stream.generate(() -> {
            Properties params = new Properties();
            parameters.forEach((name, values) -> params.setProperty(name, sample(values)));
            return params;
        });
    }

    /** Draws one random value of a parameter, formatted as a property string. */
    private static String sample(Object values) {
        return switch (values) {
            // Single-element arrays are fixed values; skip the random draw.
            case int[] a -> String.valueOf(a.length == 1 ? a[0] : a[MathEx.randomInt(a.length)]);
            case double[] a -> String.valueOf(a.length == 1 ? a[0] : a[MathEx.randomInt(a.length)]);
            case String[] a -> a.length == 1 ? a[0] : a[MathEx.randomInt(a.length)];
            // MathEx.randomInt(lo, hi) draws from the half-open [lo, hi); widen
            // by one so the documented inclusive end can be sampled.
            case IntRange range -> String.valueOf(MathEx.randomInt(range.start(), range.end() + 1));
            case DoubleRange range -> String.valueOf(MathEx.random(range.start(), range.end()));
            case null, default -> throw new IllegalStateException("Unknown parameter type: " + values);
        };
    }

    /**
     * Generates a stream of hyperparameters for grid search: the cartesian
     * product of the candidate values of all parameters, with ranges expanded
     * by their step. Parameters are combined in insertion order, with the most
     * recently added parameter varying fastest. If no parameter has been added,
     * the stream contains a single empty {@link Properties}.
     *
     * @return the stream of hyperparameters for grid search.
     */
    public Stream<Properties> grid() {
        // Start from the single empty combination and extend every combination
        // with each value of the next parameter.
        List<List<KeyValue>> combinations = List.of(List.of());
        for (var parameter : parameters.entrySet()) {
            List<KeyValue> options = values(parameter);
            List<List<KeyValue>> extended = new ArrayList<>();
            for (var prefix : combinations) {
                for (var option : options) {
                    List<KeyValue> combination = new ArrayList<>(prefix);
                    combination.add(option);
                    extended.add(combination);
                }
            }
            combinations = extended;
        }

        return combinations.stream().map(combination -> {
            Properties params = new Properties();
            combination.forEach(p -> params.setProperty(p.key(), p.value()));
            return params;
        });
    }

    /**
     * Expands a parameter into its (name, value) pairs: every array element,
     * or every grid point of a range.
     *
     * @param parameter the parameter name and its stored values.
     * @return the list of name/value pairs, in order.
     * @throws IllegalStateException if the stored values have an unknown type.
     */
    private static List<KeyValue> values(Map.Entry<String, Object> parameter) {
        String name = parameter.getKey();
        Object values = parameter.getValue();
        List<KeyValue> list = new ArrayList<>();
        switch (values) {
            case int[] array -> {
                for (int value : array) {
                    list.add(new KeyValue(name, String.valueOf(value)));
                }
            }
            case double[] array -> {
                for (double value : array) {
                    list.add(new KeyValue(name, String.valueOf(value)));
                }
            }
            case String[] array -> {
                for (String value : array) {
                    list.add(new KeyValue(name, String.valueOf(value)));
                }
            }
            case IntRange range -> {
                for (int value : range.toArray()) {
                    list.add(new KeyValue(name, String.valueOf(value)));
                }
            }
            case DoubleRange range -> {
                for (double value : range.toArray()) {
                    list.add(new KeyValue(name, String.valueOf(value)));
                }
            }
            case null, default -> throw new IllegalStateException("Unknown parameter type: " + values);
        }
        return list;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy