All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.commons.math3.distribution.MixtureMultivariateRealDistribution Maven / Gradle / Ivy

There is a newer version: 5.0.71
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math3.distribution;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.MathArithmeticException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.util.Pair;

/**
 * Class for representing 
 * mixture model distributions.
 *
 * @param  Type of the mixture components.
 *
 * @since 3.1
 */
public class MixtureMultivariateRealDistribution
    extends AbstractMultivariateRealDistribution {
    /** Normalized weight of each mixture component. */
    private final double[] weight;
    /** Mixture components. */
    private final List distribution;

    /**
     * Creates a mixture model from a list of distributions and their
     * associated weights.
     * 

* Note: this constructor will implicitly create an instance of * {@link Well19937c} as random generator to be used for sampling only (see * {@link #sample()} and {@link #sample(int)}). In case no sampling is * needed for the created distribution, it is advised to pass {@code null} * as random generator via the appropriate constructors to avoid the * additional initialisation overhead. * * @param components List of (weight, distribution) pairs from which to sample. */ public MixtureMultivariateRealDistribution(List> components) { this(new Well19937c(), components); } /** * Creates a mixture model from a list of distributions and their * associated weights. * * @param rng Random number generator. * @param components Distributions from which to sample. * @throws NotPositiveException if any of the weights is negative. * @throws DimensionMismatchException if not all components have the same * number of variables. */ public MixtureMultivariateRealDistribution(RandomGenerator rng, List> components) { super(rng, components.get(0).getSecond().getDimension()); final int numComp = components.size(); final int dim = getDimension(); double weightSum = 0; for (int i = 0; i < numComp; i++) { final Pair comp = components.get(i); if (comp.getSecond().getDimension() != dim) { throw new DimensionMismatchException(comp.getSecond().getDimension(), dim); } if (comp.getFirst() < 0) { throw new NotPositiveException(comp.getFirst()); } weightSum += comp.getFirst(); } // Check for overflow. if (Double.isInfinite(weightSum)) { throw new MathArithmeticException(LocalizedFormats.OVERFLOW); } // Store each distribution and its normalized weight. distribution = new ArrayList(); weight = new double[numComp]; for (int i = 0; i < numComp; i++) { final Pair comp = components.get(i); weight[i] = comp.getFirst() / weightSum; distribution.add(comp.getSecond()); } } /** {@inheritDoc} */ public double density(final double[] values) { double p = 0; for (int i = 0; i < weight.length; i++) { p += weight[i] * distribution.get(i).density(values); } return p; } /** {@inheritDoc} */ @Override public double[] sample() { // Sampled values. double[] vals = null; // Determine which component to sample from. final double randomValue = random.nextDouble(); double sum = 0; for (int i = 0; i < weight.length; i++) { sum += weight[i]; if (randomValue <= sum) { // pick model i vals = distribution.get(i).sample(); break; } } if (vals == null) { // This should never happen, but it ensures we won't return a null in // case the loop above has some floating point inequality problem on // the final iteration. vals = distribution.get(weight.length - 1).sample(); } return vals; } /** {@inheritDoc} */ @Override public void reseedRandomGenerator(long seed) { // Seed needs to be propagated to underlying components // in order to maintain consistency between runs. super.reseedRandomGenerator(seed); for (int i = 0; i < distribution.size(); i++) { // Make each component's seed different in order to avoid // using the same sequence of random numbers. distribution.get(i).reseedRandomGenerator(i + 1 + seed); } } /** * Gets the distributions that make up the mixture model. * * @return the component distributions and associated weights. */ public List> getComponents() { final List> list = new ArrayList>(weight.length); for (int i = 0; i < weight.length; i++) { list.add(new Pair(weight[i], distribution.get(i))); } return list; } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy