All downloads are free. Search and download functionality uses the official Maven repository.

com.github.vangj.jbayes.inf.exact.sampling.Table Maven / Gradle / Ivy

The newest version!
package com.github.vangj.jbayes.inf.exact.sampling;

import com.github.vangj.jbayes.inf.exact.graph.Node;
import com.github.vangj.jbayes.inf.exact.graph.Variable;
import com.github.vangj.jbayes.inf.exact.graph.util.NodeUtil;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

public class Table {

  private final Node node;
  private final List parents;
  private final List parentIds;
  private Map> probs;

  public Table(Node node, List parents) {
    this.node = node;
    parents.sort((n1, n2) -> n1.getId().compareTo(n2.getId()));
    this.parents = parents;
    this.parentIds = this.parents.stream()
        .map(Variable::getId)
        .collect(Collectors.toList());

    this.probs = new HashMap<>();
    if (0 == this.parents.size()) {
      List cumsums = getCumsums(node.probs());
      this.probs.put("default", cumsums);
    } else {
      List> lists = parents.stream()
          .map(Variable::getValues)
          .map(s -> (List) new ArrayList(s))
          .collect(Collectors.toList());

      List> cartesian = NodeUtil.product(lists);
      List keys = cartesian.stream()
          .map(values -> {
            final int paSize = this.parents.size();
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < paSize; i++) {
              Node n = this.parents.get(i);
              String v = values.get(i);
              String s = n.getId() + "=" + v;
              sb.append(s);
              if (i < paSize - 1) {
                sb.append(",");
              }
            }
            return sb.toString();
          })
          .collect(Collectors.toList());

      final int n = node.getValues().size();
      List nodeProbs = node.probs();
      List> probs = NodeUtil.groupList(n, nodeProbs).stream()
          .map(Table::getCumsums)
          .collect(Collectors.toList());

      this.probs = new HashMap<>();
      final int numPairs = keys.size();
      for (int i = 0; i < numPairs; i++) {
        String k = keys.get(i);
        List v = probs.get(i);
        this.probs.put(k, v);
      }
    }
  }

  private static List getCumsums(List probs) {
    final int size = probs.size();
    List cumsums = new ArrayList<>();
    cumsums.add(probs.get(0));
    for (int i = 1; i < size; i++) {
      double sum = cumsums.get(i - 1) + probs.get(i);
      cumsums.add(sum);
    }
    return cumsums;
  }

  public String getValue(final double prob, Map sample) {
    if (!this.hasParents()) {
      List probs = this.probs.get("default");
      int index = NodeUtil.bisectRight(probs, prob);
      return node.getValueList().get(index);
    } else {
      String k = this.parentIds.stream()
          .map(id -> id + "=" + sample.get(id))
          .collect(Collectors.joining(","));
      List probs = this.probs.get(k);
      int index = NodeUtil.bisectRight(probs, prob);
      return node.getValueList().get(index);
    }
  }

  public boolean hasParents() {
    return parents.size() > 0;
  }

  public Map> getProbs() {
    return probs;
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy