edu.cmu.sv.utils.HypothesisSetManagement Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of yoda Show documentation
Show all versions of yoda Show documentation
A library that allows rapid prototyping of dialog systems (language understanding, discourse modelling, dialog management, language generation).
package edu.cmu.sv.utils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
* Created by David Cohen on 9/4/14.
*
* Functions for managing sets of hypotheses.
*
*/
public class HypothesisSetManagement {
public static List> keepNBestBeam(Map asMap, int beamSize){
Set> ans = asMap.keySet().stream().
map(key -> new ImmutablePair<>(key, asMap.get(key))).
collect(Collectors.toSet());
return keepNBestBeam(ans, beamSize);
}
public static List> keepNBestBeam(Set> fullSet, int beamSize){
List> ans = fullSet.stream().
sorted(Comparator.comparing((Function, Double>) Pair::getValue).reversed()).
collect(Collectors.toList());
if (ans.size()> beamSize){
ans = ans.subList(0,beamSize);
}
return ans;
}
public static List> keepRatioBeam(Map asMap, double ratio, int beamSize){
Set> ans = asMap.keySet().stream().
map(key -> new ImmutablePair<>(key, asMap.get(key))).
collect(Collectors.toSet());
return keepRatioBeam(ans, ratio, beamSize);
}
public static List> keepRatioBeam(Set> fullSet, double ratio, int maxBeamSize){
assert ratio <= 1;
double maxWeight = fullSet.stream().map(Pair::getRight).max(Double::compare).get();
List> ans = fullSet.stream().
sorted(Comparator.comparing((Function, Double>) Pair::getValue).reversed()).
filter(x -> x.getRight() >= ratio * maxWeight).
collect(Collectors.toList());
if (ans.size()> maxBeamSize){
ans = ans.subList(0,maxBeamSize);
}
return ans;
}
public static NBestDistribution keepRatioDistribution(NBestDistribution fullSet, double ratio, int maxBeamSize){
List> nBest = keepRatioBeam(fullSet.getInternalDistribution(), ratio, maxBeamSize);
NBestDistribution ans = new NBestDistribution<>();
nBest.forEach(x -> ans.put(x.getLeft(), x.getRight()));
return ans;
}
public static StringDistribution keepRatioDistribution(StringDistribution fullSet, double ratio, int maxBeamSize){
List> nBest = keepRatioBeam(fullSet.getInternalDistribution(), ratio, maxBeamSize);
StringDistribution ans = new StringDistribution();
nBest.forEach(x -> ans.put(x.getLeft(), x.getRight()));
return ans;
}
public static StringDistribution keepNBestDistribution(StringDistribution fullSet, int beamSize){
List> nBest = keepNBestBeam(fullSet.getInternalDistribution(), beamSize);
StringDistribution ans = new StringDistribution();
nBest.forEach(x -> ans.put(x.getLeft(), x.getRight()));
return ans;
}
public static Pair>>
getJointFromMarginals(Map marginals, int beamSize){
Map> combinationsInput = new HashMap<>();
marginals.entrySet().stream().forEach(x -> combinationsInput.put(x.getKey(), new HashSet<>(x.getValue().keySet())));
StringDistribution jointDistribution = new StringDistribution();
Map> jointAssignments = new HashMap<>();
int i=0;
for (Map jointAssignment : Combination.possibleBindings(combinationsInput)){
Double probability = 1.0;
String assignmentID = "assignmentID_"+i;
for (String key : jointAssignment.keySet()){
probability*=marginals.get(key).get(jointAssignment.get(key));
}
jointDistribution.put(assignmentID, probability);
jointAssignments.put(assignmentID, jointAssignment);
i++;
}
Map> beamOfAssignments = new HashMap<>();
jointDistribution = keepNBestDistribution(jointDistribution, beamSize);
final StringDistribution finalJointDistribution = jointDistribution;
jointAssignments.entrySet().stream().filter(x -> finalJointDistribution.containsKey(x.getKey())).
forEach(x -> beamOfAssignments.put(x.getKey(), x.getValue()));
return new ImmutablePair<>(jointDistribution, beamOfAssignments);
}
}