All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.rinde.rinsim.util.StochasticSuppliers Maven / Gradle / Ivy

There is a newer version: 4.4.6
Show newest version
/*
 * Copyright (C) 2011-2016 Rinde van Lon, iMinds-DistriNet, KU Leuven
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.github.rinde.rinsim.util;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;

import java.io.Serializable;
import java.math.RoundingMode;
import java.util.Iterator;

import javax.annotation.Nonnull;

import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.random.MersenneTwister;

import com.google.auto.value.AutoValue;
import com.google.common.base.Predicate;
import com.google.common.base.Supplier;
import com.google.common.math.DoubleMath;
import com.google.common.primitives.Doubles;
import com.google.common.reflect.TypeToken;

/**
 * Utility class for {@link StochasticSupplier}.
 * @author Rinde van Lon
 */
public final class StochasticSuppliers {

  private StochasticSuppliers() {}

  /**
   * Create a {@link StochasticSupplier} that will always return the specified
   * value.
   * @param value The value which the supplier will return.
   * @param  Type of constant.
   * @return A supplier that always returns the specified value.
   */
  public static  StochasticSupplier constant(T value) {
    return new ConstantSupplier<>(value);
  }

  /**
   * Checks whether the provided supplier is a constant created by
   * {@link #constant(Object)}.
   * @param supplier The supplier to check.
   * @param  The type this supplier generates.
   * @return true if the provided supplier is created by
   *         {@link #constant(Object)}, false otherwise.
   */
  public static  boolean isConstant(StochasticSupplier supplier) {
    return supplier.getClass() == ConstantSupplier.class;
  }

  /**
   * Creates a {@link StochasticSupplier} that will always throw an
   * {@link IllegalArgumentException} with the specified errorMsg.
   * This can be useful when a default 'empty' supplier is needed.
   * @param errorMsg The error message of the exception.
   * @param  The type this supplier generates.
   * @return A supplier that always throws an exception.
   */
  public static  StochasticSupplier empty(String errorMsg) {
    return new EmptySupplier<>(errorMsg);
  }

  /**
   * Decorates the specified {@link StochasticSupplier} such that when it
   * produces values which are not allowed by the specified predicate an
   * {@link IllegalArgumentException} is thrown.
   * @param supplier The supplier to be decorated.
   * @param predicate The predicate which specifies the contract to which the
   *          supplier should adhere.
   * @param  The type this supplier generates.
   * @return A supplier that is guaranteed to return values which match the
   *         given predicate or throw an {@link IllegalArgumentException}.
   */
  public static  StochasticSupplier checked(
      StochasticSupplier supplier,
      Predicate predicate) {
    return new CheckedSupplier<>(supplier, predicate);
  }

  /**
   * Create a {@link StochasticSupplier} based on an {@link Iterable}. It will
   * return the values in the order as defined by the iterable. The resulting
   * supplier will throw an {@link IllegalArgumentException} when the iterable
   * is empty.
   * @param iter The iterable from which the values will be used.
   * @param  The type this supplier generates.
   * @return A supplier based on an iterable.
   */
  public static  StochasticSupplier fromIterable(Iterable iter) {
    return new IteratorSS<>(iter.iterator());
  }

  /**
   * Create a {@link StochasticSupplier} based on a {@link Supplier}.
   * @param supplier The supplier to adapt.
   * @param  The type this supplier generates.
   * @return The adapted supplier.
   */
  public static  StochasticSupplier fromSupplier(Supplier supplier) {
    return new SupplierAdapter<>(supplier);
  }

  /**
   * @return Builder for constructing {@link StochasticSupplier}s that produce
   *         normal (Gaussian) distributed numbers.
   */
  public static Builder normal() {
    return new Builder();
  }

  /**
   * Creates a {@link StochasticSupplier} that produces uniformly distributed
   * {@link Double}s.
   * @param lower The (inclusive) lower bound of the uniform distribution.
   * @param upper The (inclusive) upper bound of the uniform distribution.
   * @return The supplier.
   */
  public static StochasticSupplier uniformDouble(double lower,
      double upper) {
    return new DoubleDistributionSS(new UniformRealDistribution(
      new MersenneTwister(), lower, upper));
  }

  /**
   * Creates a {@link StochasticSupplier} that produces uniformly distributed
   * {@link Integer}s.
   * @param lower The (inclusive) lower bound of the uniform distribution.
   * @param upper The (inclusive) upper bound of the uniform distribution.
   * @return The supplier.
   */
  public static StochasticSupplier uniformInt(int lower, int upper) {
    return new IntegerDistributionSS(new UniformIntegerDistribution(
      new MersenneTwister(), lower, upper));
  }

  /**
   * Creates a {@link StochasticSupplier} that produces uniformly distributed
   * {@link Long}s.
   * @param lower The (inclusive) lower bound of the uniform distribution.
   * @param upper The (inclusive) upper bound of the uniform distribution.
   * @return The supplier.
   */
  public static StochasticSupplier uniformLong(int lower, int upper) {
    return intToLong(uniformInt(lower, upper));
  }

  /**
   * Convert a {@link StochasticSupplier} of {@link Integer} to a supplier of
   * {@link Long}.
   * @param supplier The supplier to convert.
   * @return The converted supplier.
   */
  public static StochasticSupplier intToLong(
      StochasticSupplier supplier) {
    return new IntToLongAdapter(supplier);
  }

  /**
   * Convert a {@link StochasticSupplier} of {@link Double} to a supplier of
   * {@link Integer}.
   * @param supplier The supplier to convert.
   * @return The converted supplier.
   */
  public static StochasticSupplier roundDoubleToInt(
      StochasticSupplier supplier) {
    return new DoubleToIntAdapter(supplier);
  }

  /**
   * Convert a {@link StochasticSupplier} of {@link Double} to a supplier of
   * {@link Long}.
   * @param supplier The supplier to convert.
   * @return The converted supplier.
   */
  public static StochasticSupplier roundDoubleToLong(
      StochasticSupplier supplier) {
    return new DoubleToLongAdapter(supplier);
  }

  /**
   * @return A {@link StochasticSupplier} of {@link MersenneTwister}.
   */
  public static StochasticSupplier mersenneTwister() {
    return MersenneTwisterSS.create();
  }

  /**
   * Builder for creating {@link StochasticSupplier}s that return a number with
   * a normal distribution.
   * @author Rinde van Lon
   */
  public static class Builder {
    static final double SMALLEST_DOUBLE = 0.000000000000001;
    static final int MAX_ITERATIONS = 1000000;
    static final double STEP_SIZE_DENOMINATOR = 1.5d;
    private double mean;
    private double std;
    private double lowerBound;
    private double upperBound;
    private OutOfBoundStrategy outOfBoundStrategy;

    Builder() {
      mean = 0;
      std = 1;
      lowerBound = Double.NEGATIVE_INFINITY;
      upperBound = Double.POSITIVE_INFINITY;
      outOfBoundStrategy = OutOfBoundStrategy.REDRAW;
    }

    /**
     * Set the mean of the normal distribution.
     * @param m The mean. Default value: 0.
     * @return This, as per the builder pattern.
     */
    public Builder mean(double m) {
      mean = m;
      return this;
    }

    /**
     *
     * @param sd The standard deviation. Default value: 1.
     * @return This, as per the builder pattern.
     */
    public Builder std(double sd) {
      std = sd;
      return this;
    }

    /**
     *
     * @param var The variance. Default value: 1.
     * @return This, as per the builder pattern.
     */
    public Builder variance(double var) {
      std = Math.sqrt(var);
      return this;
    }

    /**
     * Truncates the normal distribution using the lower and upper bounds. In
     * case a number is drawn outside the bounds:
     *  x < lower || x > upper the out of bound strategy
     * defines what will happen. See {@link #redrawWhenOutOfBounds()} and
     * {@link #roundWhenOutOfBounds()} for the options. Note that calling this
     * method may change the effective mean and standard deviation of the normal
     * distribution. If this is undesired you can choose to scale the mean of
     * the distribution, see {@link #scaleMean()} for more details.
     * @param lower The lower bound. Default value:
     *          {@link Double#NEGATIVE_INFINITY}.
     * @param upper The upper bound. Default value:
     *          {@link Double#POSITIVE_INFINITY}.
     * @return This, as per the builder pattern.
     */
    public Builder bounds(double lower, double upper) {
      lowerBound = lower;
      upperBound = upper;
      return this;
    }

    /**
     * Sets the lower bound, see {@link #bounds(double, double)} for more
     * information.
     * @param lower The lower bound.
     * @return This, as per the builder pattern.
     */
    public Builder lowerBound(double lower) {
      lowerBound = lower;
      return this;
    }

    /**
     * Sets the upper bound, see {@link #bounds(double, double)} for more
     * information.
     * @param upper The upper bound.
     * @return This, as per the builder pattern.
     */
    public Builder upperBound(double upper) {
      upperBound = upper;
      return this;
    }

    /**
     * Calling this method will set the out of bounds strategy to redraw. This
     * means that when a number is drawn from the distribution that is out of
     * bounds a new number will be drawn. This will continue until a value is
     * found within bounds. Note that when the bounds are small relative to the
     * distribution this may result in a large number of attempts. By default
     * this strategy is enabled.
     * @return This, as per the builder pattern.
     */
    public Builder redrawWhenOutOfBounds() {
      outOfBoundStrategy = OutOfBoundStrategy.REDRAW;
      return this;
    }

    /**
     * Calling this method will set the out of bounds strategy to redraw. This
     * means that when a number is drawn from the distribution that is out of
     * bounds the number will be rounded to the nearest bound. By default this
     * strategy is disabled.
     * @return This, as per the builder pattern.
     */
    public Builder roundWhenOutOfBounds() {
      outOfBoundStrategy = OutOfBoundStrategy.ROUND;
      return this;
    }

    /**
     * Scale the normal distribution such that the effective mean is as given by
     * {@link #mean(double)} in case a lower bound was set. This method can only
     * be called if the following requirements are met:
     * 
    *
  • Lower bound must be set
  • *
  • Out of bound strategy: {@link #redrawWhenOutOfBounds()}.
  • *
* Note that this method overwrites any previous (but not subsequent) calls * to {@link #mean(double)}. If after calling this method the bounds and/or * out of bound strategy are changed this may yield unexpected results. * Also, using an upper bound is currently not supported. *

* For more information about how the effective mean of the truncated normal * distribution is calculated, see this * Wikipedia article. * @return This, as per the builder pattern. */ public Builder scaleMean() { checkArgument(!Double.isInfinite(lowerBound), "A lower bound must be set in order to scale the mean."); checkArgument(Double.isInfinite(upperBound), "Scaling the mean with an upper bound is currently not supported."); checkArgument(OutOfBoundStrategy.REDRAW == outOfBoundStrategy); double stepSize = 1; double curMean = mean; double dir = 0; double effectiveMean; int iterations = 0; do { effectiveMean = computeEffectiveMean(curMean, std, lowerBound); // save direction final double oldDir = dir; if (effectiveMean > mean) { dir = 1d; } else { dir = -1d; } // if direction changed decrease step size if (dir != oldDir && oldDir != 0) { stepSize /= STEP_SIZE_DENOMINATOR; } // apply step if (effectiveMean > mean) { curMean -= stepSize; } else { curMean += stepSize; } iterations++; checkState(iterations < MAX_ITERATIONS, "Could not converge. Target mean: %s, effective mean: %s.", mean, effectiveMean); } while (Math.abs(effectiveMean - mean) > SMALLEST_DOUBLE); mean = curMean; return this; } /* * Computes effective mean using * https://en.wikipedia.org/wiki/Truncated_normal_distribution#Moments . */ private static double computeEffectiveMean(double m, double s, double lb) { final NormalDistribution normal = new NormalDistribution(); final double alpha = (lb - m) / s; final double pdf = normal.density(alpha); final double cdf = normal.cumulativeProbability(alpha); final double lambda = pdf / (1 - cdf); return m + s * lambda; } /** * @return A {@link StochasticSupplier} that draws double values from a * normal distribution. */ public StochasticSupplier buildDouble() { checkArgument(mean + std >= lowerBound); checkArgument(mean + std <= upperBound); final RealDistribution distribution = new NormalDistribution(mean, std); if (Doubles.isFinite(lowerBound) || Doubles.isFinite(upperBound)) { return new BoundedDoubleDistSS(distribution, upperBound, lowerBound, outOfBoundStrategy); } return new DoubleDistributionSS(distribution); } /** * @return A {@link StochasticSupplier} that draws integer values from a * normal distribution. */ public StochasticSupplier buildInteger() { integerChecks(); return roundDoubleToInt(buildDouble()); } /** * @return A {@link StochasticSupplier} that draws long values from a normal * distribution. */ public StochasticSupplier buildLong() { integerChecks(); return roundDoubleToLong(buildDouble()); } void integerChecks() { checkArgument(Double.isInfinite(lowerBound) || DoubleMath.isMathematicalInteger(lowerBound)); checkArgument(Double.isInfinite(upperBound) || DoubleMath.isMathematicalInteger(upperBound)); } } /** * Abstract implementation providing a default {@link #toString()} * implementation. * @author Rinde van Lon * @param The type of objects that this supplier creates. */ public abstract static class AbstractStochasticSupplier implements StochasticSupplier, Serializable { private static final long serialVersionUID = 992219257352250656L; @Override public String toString() { return new TypeToken(getClass()) { private static final long serialVersionUID = 4641163444574558674L; }.getRawType().getSimpleName() + "Supplier"; } } enum OutOfBoundStrategy { ROUND, REDRAW } private static class IntToLongAdapter extends AbstractStochasticSupplier { private static final long serialVersionUID = 3638307177262422449L; private final StochasticSupplier supplier; IntToLongAdapter(StochasticSupplier supp) { supplier = supp; } @Override public Long get(long seed) { return Long.valueOf(supplier.get(seed)); } } private static class DoubleToIntAdapter extends AbstractStochasticSupplier { private static final long serialVersionUID = 3086452659883375531L; private final StochasticSupplier supplier; DoubleToIntAdapter(StochasticSupplier supp) { supplier = supp; } @Override public Integer get(long seed) { return DoubleMath.roundToInt(supplier.get(seed), RoundingMode.HALF_UP); } } private static class DoubleToLongAdapter extends AbstractStochasticSupplier { private static final long serialVersionUID = -8846720318135533333L; private final StochasticSupplier supplier; DoubleToLongAdapter(StochasticSupplier supp) { supplier = supp; } @Override public Long get(long seed) { return DoubleMath.roundToLong(supplier.get(seed), RoundingMode.HALF_UP); } } private static class IntegerDistributionSS extends AbstractStochasticSupplier { private static final long serialVersionUID = -7967542154741162460L; private final IntegerDistribution distribution; IntegerDistributionSS(IntegerDistribution id) { distribution = id; } @Override public Integer get(long seed) { distribution.reseedRandomGenerator(seed); return distribution.sample(); } } private static class BoundedDoubleDistSS extends AbstractStochasticSupplier { private static final long serialVersionUID = -6738290534532097051L; private final RealDistribution distribution; private final double lowerBound; private final double upperBound; private final OutOfBoundStrategy outOfBoundStrategy; BoundedDoubleDistSS(RealDistribution rd, double upper, double lower, OutOfBoundStrategy strategy) { checkArgument(strategy == OutOfBoundStrategy.REDRAW || strategy == OutOfBoundStrategy.ROUND); distribution = rd; lowerBound = lower; upperBound = upper; outOfBoundStrategy = strategy; } @Override public Double get(long seed) { distribution.reseedRandomGenerator(seed); double val = distribution.sample(); if (outOfBoundStrategy == OutOfBoundStrategy.REDRAW) { while (!isInBounds(val)) { val = distribution.sample(); } } else if (val < lowerBound) { val = lowerBound; } else if (val >= upperBound) { val = upperBound; } return val; } boolean isInBounds(double val) { return val >= lowerBound && val < upperBound; } } private static class DoubleDistributionSS extends AbstractStochasticSupplier { private static final long serialVersionUID = -5853417575632121095L; private final RealDistribution distribution; DoubleDistributionSS(RealDistribution rd) { distribution = rd; } @Override public Double get(long seed) { distribution.reseedRandomGenerator(seed); return distribution.sample(); } } private static class IteratorSS extends AbstractStochasticSupplier { private static final long serialVersionUID = 3151363361183354655L; private final Iterator iterator; IteratorSS(Iterator it) { iterator = it; } @Override public T get(long seed) { if (iterator.hasNext()) { return iterator.next(); } throw new IllegalStateException("This supplier is exhausted."); } } private static final class ConstantSupplier extends AbstractStochasticSupplier { private static final long serialVersionUID = -5017806121674846656L; private final T value; ConstantSupplier(T v) { value = v; } @Override @Nonnull public T get(long seed) { return value; } @Override public String toString() { return String.format("%s.constant(%s)", StochasticSuppliers.class.getSimpleName(), value); } } private static final class EmptySupplier extends AbstractStochasticSupplier { private static final long serialVersionUID = 1993638453016457007L; private final String message; EmptySupplier(String msg) { message = msg; } @Override @Nonnull public T get(long seed) { throw new IllegalArgumentException(message); } @Override public String toString() { return String.format("%s.empty()", StochasticSuppliers.class.getSimpleName()); } } private static class SupplierAdapter extends AbstractStochasticSupplier { private static final long serialVersionUID = 1388067842132493130L; private final Supplier supplier; SupplierAdapter(Supplier sup) { supplier = sup; } @Override public T get(long seed) { return supplier.get(); } } private static class CheckedSupplier implements StochasticSupplier { private final StochasticSupplier supplier; private final Predicate predicate; CheckedSupplier(StochasticSupplier sup, Predicate pred) { supplier = sup; predicate = pred; } @Override public T get(long seed) { final T value = supplier.get(seed); checkArgument(predicate.apply(value), "The supplier generated an invalid value: %s.", value); return value; } } @AutoValue abstract static class MersenneTwisterSS implements StochasticSupplier { @Override public MersenneTwister get(long seed) { return new MersenneTwister(seed); } static MersenneTwisterSS create() { return new AutoValue_StochasticSuppliers_MersenneTwisterSS(); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy