All downloads are free. Search and download functionality uses the official Maven repository.

io.github.ericmedvet.jgea.problem.mapper.FitnessFunction Maven / Gradle / Ivy

The newest version!
/*-
 * ========================LICENSE_START=================================
 * jgea-problem
 * %%
 * Copyright (C) 2018 - 2024 Eric Medvet
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =========================LICENSE_END==================================
 */

package io.github.ericmedvet.jgea.problem.mapper;

import io.github.ericmedvet.jgea.core.distance.BitStringHamming;
import io.github.ericmedvet.jgea.core.distance.Distance;
import io.github.ericmedvet.jgea.core.operator.GeneticOperator;
import io.github.ericmedvet.jgea.core.representation.sequence.bit.BitString;
import io.github.ericmedvet.jgea.core.representation.sequence.bit.BitStringFactory;
import io.github.ericmedvet.jgea.core.representation.sequence.bit.BitStringFlipMutation;
import io.github.ericmedvet.jgea.core.representation.tree.Tree;
import io.github.ericmedvet.jgea.core.util.LinkedHashMultiset;
import io.github.ericmedvet.jgea.core.util.Multiset;
import io.github.ericmedvet.jnb.datastructure.Pair;
import java.util.*;
import java.util.function.Function;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.stat.correlation.PearsonsCorrelation;

public class FitnessFunction implements Function, Tree>, List> {

  private static final int EXPRESSIVENESS_DEPTH = 2;
  private final List problems;
  private final int maxMappingDepth;
  private final List properties;
  private final List genotypes;
  private final double[] genotypeDistances;

  public FitnessFunction(
      List problems,
      int genotypeSize,
      int n,
      int maxMappingDepth,
      List properties,
      long seed) {
    this.problems = problems;
    this.maxMappingDepth = maxMappingDepth;
    this.properties = properties;
    Random random = new Random(seed);
    // build genotypes
    GeneticOperator mutation = new BitStringFlipMutation(0.01d);
    BitStringFactory factory = new BitStringFactory(genotypeSize);
    Set set = new LinkedHashSet<>();
    for (int i = 0; i < Math.floor(Math.sqrt(n)); i++) {
      set.addAll(consecutiveMutations(factory.build(random), (int) Math.floor(Math.sqrt(n)), mutation, random));
    }
    while (set.size() < n) {
      set.add(factory.build(random));
    }
    genotypes = new ArrayList<>(set);
    // compute distances
    genotypeDistances = computeDistances(genotypes, new BitStringHamming());
  }

  public enum Property {
    DEGENERACY,
    NON_UNIFORMITY,
    NON_LOCALITY
  }

  @SuppressWarnings({"rawtypes", "unchecked"})
  @Override
  public List apply(Pair, Tree> pair) {
    List> valuesLists = new ArrayList<>();
    for (EnhancedProblem problem : problems) {
      List localValues = apply(pair, problem);
      if (valuesLists.isEmpty()) {
        localValues.forEach(v -> {
          List valuesList = new ArrayList<>(problems.size());
          valuesLists.add(valuesList);
        });
      }
      for (int i = 0; i < localValues.size(); i++) {
        valuesLists.get(i).add(localValues.get(i));
      }
    }
    return valuesLists.stream()
        .map(valuesList -> valuesList.stream()
            .mapToDouble(Double::doubleValue)
            .average()
            .orElse(Double.NaN))
        .toList();
  }

  protected  List apply(Pair, Tree> pair, EnhancedProblem problem) {
    // build mapper
    RecursiveMapper recursiveMapper = new RecursiveMapper<>(
        pair.first(),
        pair.second(),
        maxMappingDepth,
        EXPRESSIVENESS_DEPTH,
        problem.problem().getGrammar());
    // map
    List solutions = genotypes.stream()
        .map(recursiveMapper)
        .map(t -> problem.problem().getSolutionMapper().apply(t))
        .toList();
    Multiset multiset = new LinkedHashMultiset<>(solutions);
    multiset.addAll(solutions);
    // compute properties
    List values = new ArrayList<>();
    for (Property property : properties) {
      if (property.equals(Property.DEGENERACY)) {
        values.add(1d - (double) multiset.elementSet().size() / (double) genotypes.size());
      } else if (property.equals(Property.NON_UNIFORMITY)) {
        double[] sizes = multiset.elementSet().stream()
            .mapToDouble(multiset::count)
            .toArray();
        values.add(Math.sqrt(StatUtils.variance(sizes)) / StatUtils.mean(sizes));
      } else if (property.equals(Property.NON_LOCALITY)) {
        double[] solutionDistances = computeDistances(solutions, problem.distance());
        double locality =
            1d - (1d + (new PearsonsCorrelation().correlation(genotypeDistances, solutionDistances))) / 2d;
        values.add(Double.isNaN(locality) ? 1d : locality);
      } else {
        values.add(0d);
      }
    }
    return values;
  }

  private  double[] computeDistances(List elements, Distance distance) {
    double[] dists = new double[elements.size() * (elements.size() - 1) / 2];
    int c = 0;
    for (int i = 0; i < elements.size() - 1; i++) {
      for (int j = i + 1; j < elements.size(); j++) {
        dists[c] = distance.apply(elements.get(i), elements.get(j));
        c = c + 1;
      }
    }
    return dists;
  }

  private List consecutiveMutations(
      BitString g, int n, GeneticOperator mutation, Random random) {
    Set set = new LinkedHashSet<>();
    while (set.size() < n) {
      set.add(g);
      g = mutation.apply(Collections.singletonList(g), random).getFirst();
    }
    return new ArrayList<>(set);
  }
}