// Source listing of com.cogpunk.math.probability.VariableProbabilityProfileAggregator
// from the cogpunk-math artifact (Maven / Gradle / Ivy repository page; all versions and documentation available there).
package com.cogpunk.math.probability;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import com.cogpunk.math.NumberOperator;
import org.apache.commons.lang3.builder.HashCodeBuilder;
import org.apache.commons.lang3.builder.EqualsBuilder;
public class VariableProbabilityProfileAggregator implements EventProbabilityProfile {
private EventProbabilityProfile repeatProbabilityProfile;
private EventProbabilityProfile probabilityProfile;
private EventProbabilityProfileAggregationStrategy aggregationStrategy;
private NumberOperator numberOperator;
private VariableZeroHandler variableZeroHandler;
public VariableProbabilityProfileAggregator(EventProbabilityProfile repeatProbabilityProfile,
EventProbabilityProfileAggregationStrategy aggregationStrategy,
NumberOperator numberOperator,
VariableZeroHandler variableZeroHandler,
EventProbabilityProfile probabilityProfile) {
super();
this.repeatProbabilityProfile = repeatProbabilityProfile;
this.aggregationStrategy = aggregationStrategy;
this.numberOperator = numberOperator;
this.variableZeroHandler = variableZeroHandler;
this.probabilityProfile = calculateProbabilityProfile(probabilityProfile);
}
@SuppressWarnings({ "unchecked", "rawtypes" })
private EventProbabilityProfile calculateProbabilityProfile(EventProbabilityProfile probabilityProfile) {
List> aggrProfs= new ArrayList>();
for (Integer r : repeatProbabilityProfile.map().keySet()) {
P prob = repeatProbabilityProfile.getProbability(r);
if (r < 0) {
throw new IllegalArgumentException("The number of repeats cannot be < 0");
} else if (r == 0) {
variableZeroHandler.handleZeroRepeats(aggrProfs, repeatProbabilityProfile);
} else {
List> profs = new ArrayList>();
for (int n = 0; n < r; n++) {
profs.add(probabilityProfile);
}
EventProbabilityProfile thisProbProf = new EventProbabilityProfileAggregator(aggregationStrategy, numberOperator, profs);
// Change probabilities based on probability of this occurring
Map modProbProf = new HashMap();
for (E e : thisProbProf.map().keySet()) {
modProbProf.put(e, numberOperator.multiply(thisProbProf.getProbability(e), prob));
}
aggrProfs.add(new SimpleProbabilityProfileImpl(modProbProf));
}
}
return assembleProbabilityProfile(aggrProfs);
}
private EventProbabilityProfile assembleProbabilityProfile(List> aggrProfs) {
Map map = new HashMap();
for (EventProbabilityProfile prof : aggrProfs) {
for (E e : prof.map().keySet()) {
P prob = prof.getProbability(e);
if (map.containsKey(e)) {
map.put(e, numberOperator.add(map.get(e), prob));
} else {
map.put(e, prob);
}
}
}
return new SimpleProbabilityProfileImpl(map);
}
@Override
public P getProbability(E event) {
return probabilityProfile.getProbability(event);
}
@Override
public Map map() {
return probabilityProfile.map();
}
@Override
public String toString() {
return "VariableProbabilityProfileAggreggator [repeatProbabilityProfile=" + repeatProbabilityProfile
+ ", probabilityProfile="
+ probabilityProfile + ", aggregationStrategy=" + aggregationStrategy + ", numberOperator="
+ numberOperator + "]";
}
/**
* {@inheritDoc}
*/
@Override
public boolean equals(final Object other) {
if (!(other instanceof VariableProbabilityProfileAggregator)) {
return false;
}
VariableProbabilityProfileAggregator,?> castOther = (VariableProbabilityProfileAggregator,?>) other;
return new EqualsBuilder().append(repeatProbabilityProfile, castOther.repeatProbabilityProfile)
.append(probabilityProfile, castOther.probabilityProfile)
.append(aggregationStrategy, castOther.aggregationStrategy)
.append(numberOperator, castOther.numberOperator)
.append(variableZeroHandler, castOther.variableZeroHandler).isEquals();
}
/**
* {@inheritDoc}
*/
@Override
public int hashCode() {
return new HashCodeBuilder().append(repeatProbabilityProfile).append(probabilityProfile)
.append(aggregationStrategy).append(numberOperator).append(variableZeroHandler).toHashCode();
}
}