All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.cli.gridsearch.GridSearch Maven / Gradle / Ivy

package com.expleague.ml.cli.gridsearch;

import com.expleague.commons.math.Func;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.ml.TargetFunc;
import com.expleague.ml.cli.builders.methods.MethodsBuilder;
import com.expleague.ml.data.tools.DataTools;
import com.expleague.ml.data.tools.Pool;
import com.expleague.ml.methods.VecOptimization;

import java.util.Arrays;
import java.util.Objects;

/**
 * Exhaustive grid search over the parameters of an optimization method.
 *
 * <p>A method scheme string containing {@link String#format} placeholders is filled with every
 * parameter combination yielded by a {@code ParametersGridEnumerator}. Each concrete scheme is
 * turned into a {@link VecOptimization} by the {@link MethodsBuilder}, fitted on the learn pool,
 * and scored: the loss on the learn-pool predictions, every metric on the test-pool predictions.
 *
 * <p>User: qdeee
 * Date: 25.03.15
 */
public class GridSearch {
  private final Pool learn;
  private final Pool test;
  private final MethodsBuilder methodsBuilder;
  private final TargetFunc loss;
  private final Func[] metrics;

  /**
   * Creates a grid search over fixed data, loss and metrics.
   *
   * @param learn pool every candidate method is fitted on, and on whose predictions the loss is
   *     evaluated
   * @param test pool on whose predictions the metrics are evaluated
   * @param loss target function passed to {@code fit} and evaluated on learn-pool predictions
   * @param metrics functions evaluated on test-pool predictions; the array is copied, so later
   *     changes to the caller's array do not affect this instance
   * @param methodsBuilder factory turning a concrete scheme string into a method
   * @throws NullPointerException if any argument is {@code null}
   */
  public GridSearch(
      final Pool learn,
      final Pool test,
      final TargetFunc loss,
      final Func[] metrics,
      final MethodsBuilder methodsBuilder
      ) {
    this.learn = Objects.requireNonNull(learn, "learn");
    this.test = Objects.requireNonNull(test, "test");
    this.methodsBuilder = Objects.requireNonNull(methodsBuilder, "methodsBuilder");
    this.loss = Objects.requireNonNull(loss, "loss");
    // Defensive copy: metrics.length sizes every result array produced by search(), so the
    // caller mutating its own array afterwards must not change this instance's behaviour.
    this.metrics = Objects.requireNonNull(metrics, "metrics").clone();
  }

  /**
   * Evaluates every point of the parameter grid and keeps, per metric, the best point found.
   *
   * <p>For each parameter combination the values are printed to standard output, substituted into
   * {@code commonScheme} with {@link String#format}, built into a method, fitted on the learn pool
   * and scored. Entry {@code i} of the result is replaced whenever the new candidate
   * {@code isBetterThan} the current entry with respect to metric {@code i}.
   *
   * @param commonScheme method scheme containing format placeholders that receive the values of
   *     one grid point (presumably one placeholder per parameter dimension — TODO confirm)
   * @param parametersSpace candidate parameter values handed to {@code ParametersGridEnumerator}
   *     (presumably one row of candidates per parameter — TODO confirm)
   * @return one holder per metric, in the order of the metrics array; an entry stays {@code null}
   *     if no grid point was ever accepted for it (e.g. when the grid yields no points)
   */
  public OptimumHolder[] search(
      final String commonScheme,
      final Object[][] parametersSpace
  ) {
    final OptimumHolder[] optimumHolders = new OptimumHolder[metrics.length];
    final ParametersGridEnumerator<?> enumerator = new ParametersGridEnumerator<>(parametersSpace);
    while (enumerator.advance()) {
      final Object[] parameters = enumerator.getParameters();
      // Progress output for the CLI: which grid point is being evaluated.
      System.out.println(Arrays.toString(parameters));
      final String concreteScheme = String.format(commonScheme, parameters);
      final VecOptimization method = methodsBuilder.create(concreteScheme);
      final Trans result = method.fit(learn.vecData(), loss);
      // The loss is measured on the data the model was fitted on; metrics use held-out data.
      final double targetValue = loss.value(DataTools.calcAll(result, learn.vecData()));
      final double[] metricsValues = evaluateMetrics(result);
      final OptimumHolder currentHolder = new OptimumHolder(parameters, metricsValues, targetValue);
      for (int i = 0; i < optimumHolders.length; i++) {
        if (currentHolder.isBetterThan(optimumHolders[i], i)) {
          optimumHolders[i] = currentHolder;
        }
      }
    }
    return optimumHolders;
  }

  /** Scores a fitted model on the test pool with every configured metric, in metric order. */
  private double[] evaluateMetrics(final Trans model) {
    final Vec testEvaluation = DataTools.calcAll(model, test.vecData());
    final double[] values = new double[metrics.length];
    for (int i = 0; i < values.length; i++) {
      values[i] = metrics[i].value(testEvaluation);
    }
    return values;
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy