All downloads are free. Search and download functionality uses the official Maven repository.

edu.cmu.sv.utils.HypothesisSetManagement Maven / Gradle / Ivy

Go to download

A library that allows rapid prototyping of dialog systems (language understanding, discourse modelling, dialog management, language generation).

There is a newer version: 0.7.0
Show newest version
package edu.cmu.sv.utils;

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;

import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Created by David Cohen on 9/4/14.
 *
 * Functions for managing sets of hypotheses.
 *
 */
public class HypothesisSetManagement {

    public static  List> keepNBestBeam(Map asMap, int beamSize){
        Set> ans = asMap.keySet().stream().
                map(key -> new ImmutablePair<>(key, asMap.get(key))).
                collect(Collectors.toSet());
        return keepNBestBeam(ans, beamSize);
    }

    public static  List> keepNBestBeam(Set> fullSet, int beamSize){
        List> ans = fullSet.stream().
                sorted(Comparator.comparing((Function, Double>) Pair::getValue).reversed()).
                collect(Collectors.toList());

        if (ans.size()> beamSize){
            ans = ans.subList(0,beamSize);
        }
        return ans;
    }


    public static  List> keepRatioBeam(Map asMap, double ratio, int beamSize){
        Set> ans = asMap.keySet().stream().
                map(key -> new ImmutablePair<>(key, asMap.get(key))).
                collect(Collectors.toSet());
        return keepRatioBeam(ans, ratio, beamSize);
    }

    public static  List> keepRatioBeam(Set> fullSet, double ratio, int maxBeamSize){
        assert ratio <= 1;
        double maxWeight = fullSet.stream().map(Pair::getRight).max(Double::compare).get();

        List> ans = fullSet.stream().
                sorted(Comparator.comparing((Function, Double>) Pair::getValue).reversed()).
                filter(x -> x.getRight() >= ratio * maxWeight).
                collect(Collectors.toList());

        if (ans.size()> maxBeamSize){
            ans = ans.subList(0,maxBeamSize);
        }
        return ans;
    }

    public static  NBestDistribution keepRatioDistribution(NBestDistribution fullSet, double ratio, int maxBeamSize){
        List> nBest = keepRatioBeam(fullSet.getInternalDistribution(), ratio, maxBeamSize);
        NBestDistribution ans = new NBestDistribution<>();
        nBest.forEach(x -> ans.put(x.getLeft(), x.getRight()));
        return ans;
    }


    public static StringDistribution keepRatioDistribution(StringDistribution fullSet, double ratio, int maxBeamSize){
        List> nBest = keepRatioBeam(fullSet.getInternalDistribution(), ratio, maxBeamSize);
        StringDistribution ans = new StringDistribution();
        nBest.forEach(x -> ans.put(x.getLeft(), x.getRight()));
        return ans;
    }

    public static StringDistribution keepNBestDistribution(StringDistribution fullSet, int beamSize){
        List> nBest = keepNBestBeam(fullSet.getInternalDistribution(), beamSize);
        StringDistribution ans = new StringDistribution();
        nBest.forEach(x -> ans.put(x.getLeft(), x.getRight()));
        return ans;
    }

    public static Pair>>
    getJointFromMarginals(Map marginals, int beamSize){
        Map> combinationsInput = new HashMap<>();
        marginals.entrySet().stream().forEach(x -> combinationsInput.put(x.getKey(), new HashSet<>(x.getValue().keySet())));
        StringDistribution jointDistribution = new StringDistribution();
        Map> jointAssignments = new HashMap<>();
        int i=0;
        for (Map jointAssignment : Combination.possibleBindings(combinationsInput)){
            Double probability = 1.0;
            String assignmentID = "assignmentID_"+i;
            for (String key : jointAssignment.keySet()){
                probability*=marginals.get(key).get(jointAssignment.get(key));
            }
            jointDistribution.put(assignmentID, probability);
            jointAssignments.put(assignmentID, jointAssignment);
            i++;
        }

        Map> beamOfAssignments = new HashMap<>();
        jointDistribution = keepNBestDistribution(jointDistribution, beamSize);
        final StringDistribution finalJointDistribution = jointDistribution;
        jointAssignments.entrySet().stream().filter(x -> finalJointDistribution.containsKey(x.getKey())).
                forEach(x -> beamOfAssignments.put(x.getKey(), x.getValue()));
        return new ImmutablePair<>(jointDistribution, beamOfAssignments);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy