All downloads are free. Search and download functionality uses the official Maven repository.

com.simiacryptus.util.data.DensityTree Maven / Gradle / Ivy

There is a newer version: 2.1.0
Show newest version
/*
 * Copyright (c) 2019 by Andrew Charneski.
 *
 * The author licenses this file to you under the
 * Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance
 * with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.simiacryptus.util.data;

import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.Comparator;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * Density-based binary space-partitioning tree.
 *
 * <p>A {@link Node} built from a set of points recursively splits itself with axis-aligned
 * threshold rules ({@link OrthoRule}). At each node it picks the candidate split with the
 * highest fitness. Fitness is the point-weighted average, in bits, by which the bounding
 * volumes of the two halves shrink relative to the node's own bounds (see
 * {@link Node#split_ortho(int)}).
 *
 * <p>Non-finite coordinates (NaN and infinities) are ignored when computing bounds and
 * candidate thresholds.
 *
 * <p>The configuration setters return {@code this} for chaining. They are read while a
 * {@link Node} is being constructed, because construction performs the recursive split.
 */
public class DensityTree {

  /** Display names for each dimension; may be shorter than the data's dimensionality. */
  private final CharSequence[] columnNames;
  /**
   * Each side of a split must contain more than {@code max((int) (n * minSplitFract), 1)}
   * points, where n is the number of points with a finite coordinate in the split dimension.
   */
  private double minSplitFract = 0.05;
  /** Nodes with at most this many points are not split. */
  private int splitSizeThreshold = 10;
  /** Candidate splits must have a fitness strictly greater than this value. */
  private double minFitness = 4.0;
  /** Nodes whose depth is at least this value are not split. */
  private int maxDepth = Integer.MAX_VALUE;

  /**
   * Creates a tree configuration.
   *
   * @param columnNames display names for the dimensions, used when rendering rules and bounds.
   *                    Dimensions without a name are rendered as {@code "dim<index>"}.
   */
  public DensityTree(CharSequence... columnNames) {
    this.columnNames = columnNames;
  }

  /**
   * Returns the display name of dimension {@code d}.
   *
   * <p>Falls back to a positional name when no name was supplied for {@code d}. Without the
   * fallback, a tree constructed with fewer names than dimensions would fail as soon as a rule
   * is created.
   */
  private CharSequence columnName(int d) {
    return (null != columnNames && d < columnNames.length) ? columnNames[d] : "dim" + d;
  }

  /**
   * Computes the axis-aligned bounds of {@code points}.
   *
   * <p>For each dimension, min and max are taken over the finite coordinates only. A dimension
   * with no finite coordinate gets {@code NaN} for both. The dimensionality is taken from
   * {@code points[0]}.
   *
   * @param points the points to bound; must be non-empty
   * @return the bounds of the finite coordinates of {@code points}
   * @throws IllegalArgumentException if {@code points} is empty
   */
  @javax.annotation.Nonnull
  public Bounds getBounds(@javax.annotation.Nonnull double[][] points) {
    if (points.length == 0) {
      throw new IllegalArgumentException("Cannot compute bounds of an empty point set");
    }
    int dim = points[0].length;
    double[] max = IntStream.range(0, dim)
        .mapToDouble(d -> Arrays.stream(points).mapToDouble(pt -> pt[d]).filter(Double::isFinite).max().orElse(Double.NaN))
        .toArray();
    double[] min = IntStream.range(0, dim)
        .mapToDouble(d -> Arrays.stream(points).mapToDouble(pt -> pt[d]).filter(Double::isFinite).min().orElse(Double.NaN))
        .toArray();
    return new Bounds(max, min);
  }

  /** @return the minimum split fraction (see the {@code minSplitFract} field). */
  public double getMinSplitFract() {
    return minSplitFract;
  }

  /**
   * Sets the minimum fraction of points each side of a split must exceed.
   *
   * @param minSplitFract the new fraction
   * @return this instance, for chaining
   */
  @javax.annotation.Nonnull
  public com.simiacryptus.util.data.DensityTree setMinSplitFract(double minSplitFract) {
    this.minSplitFract = minSplitFract;
    return this;
  }

  /** @return the point count at or below which nodes are not split. */
  public int getSplitSizeThreshold() {
    return splitSizeThreshold;
  }

  /**
   * Sets the point count at or below which nodes are not split.
   *
   * @param splitSizeThreshold the new threshold
   * @return this instance, for chaining
   */
  @javax.annotation.Nonnull
  public com.simiacryptus.util.data.DensityTree setSplitSizeThreshold(int splitSizeThreshold) {
    this.splitSizeThreshold = splitSizeThreshold;
    return this;
  }

  /** @return the column names array supplied to the constructor (not a copy). */
  public CharSequence[] getColumnNames() {
    return columnNames;
  }

  /** @return the fitness a candidate split must strictly exceed. */
  public double getMinFitness() {
    return minFitness;
  }

  /**
   * Sets the fitness a candidate split must strictly exceed.
   *
   * @param minFitness the new minimum fitness
   * @return this instance, for chaining
   */
  @javax.annotation.Nonnull
  public com.simiacryptus.util.data.DensityTree setMinFitness(double minFitness) {
    this.minFitness = minFitness;
    return this;
  }

  /** @return the depth at or beyond which nodes are not split. */
  public int getMaxDepth() {
    return maxDepth;
  }

  /**
   * Sets the depth at or beyond which nodes are not split.
   *
   * @param maxDepth the new maximum depth
   * @return this instance, for chaining
   */
  @javax.annotation.Nonnull
  public com.simiacryptus.util.data.DensityTree setMaxDepth(int maxDepth) {
    this.maxDepth = maxDepth;
    return this;
  }

  /**
   * Axis-aligned bounding box.
   *
   * <p>A {@code NaN} entry marks a dimension in which no finite coordinate has been seen.
   */
  public class Bounds {
    /** Per-dimension upper bound (NaN when undefined). */
    @javax.annotation.Nonnull
    public final double[] max;
    /** Per-dimension lower bound (NaN when undefined). */
    @javax.annotation.Nonnull
    public final double[] min;

    public Bounds(@javax.annotation.Nonnull double[] max, @javax.annotation.Nonnull double[] min) {
      this.max = max;
      this.min = min;
      assert (max.length == min.length);
      assert (IntStream.range(0, max.length).filter(i -> Double.isFinite(max[i])).allMatch(i -> max[i] >= min[i]));
    }

    /**
     * Returns a new box that also contains {@code pt}.
     *
     * <p>Non-finite coordinates of {@code pt} are ignored. A dimension that is currently NaN
     * (no data) takes {@code pt}'s coordinate. This keeps incremental unions consistent with
     * {@link DensityTree#getBounds(double[][])} computed over the same points.
     *
     * @param pt the point to include
     * @return a new, possibly larger, box
     */
    @javax.annotation.Nonnull
    public Bounds union(@javax.annotation.Nonnull double[] pt) {
      int dim = pt.length;
      double[] newMax = new double[dim];
      double[] newMin = new double[dim];
      for (int d = 0; d < dim; d++) {
        double x = pt[d];
        if (!Double.isFinite(x)) {
          // Missing coordinate: this dimension is unchanged.
          newMax[d] = max[d];
          newMin[d] = min[d];
        } else {
          // NaN means "no finite value yet"; Math.max/min alone would propagate it forever.
          newMax[d] = Double.isNaN(max[d]) ? x : Math.max(max[d], x);
          newMin[d] = Double.isNaN(min[d]) ? x : Math.min(min[d], x);
        }
      }
      return new Bounds(newMax, newMin);
    }

    /**
     * Computes the product of the extents {@code max - min} over the dimensions whose extent
     * is finite and strictly positive.
     *
     * <p>Zero-width and undefined dimensions are skipped.
     *
     * @return the volume, or {@code NaN} if no dimension qualifies
     */
    public double getVolume() {
      int dim = min.length;
      return IntStream.range(0, dim)
          .mapToDouble(d -> max[d] - min[d])
          .filter(x -> Double.isFinite(x) && x > 0.0)
          .reduce((a, b) -> a * b)
          .orElse(Double.NaN);
    }

    /** Renders the box as {@code [name: min - max; ...]}; zero-dimensional boxes render as {@code []}. */
    @javax.annotation.Nonnull
    public String toString() {
      return "[" + IntStream.range(0, min.length)
          .mapToObj(d -> String.format("%s: %s - %s", columnName(d), min[d], max[d]))
          .collect(Collectors.joining("; ")) + "]";
    }

  }

  /** Routes points whose coordinate in {@code dim} is strictly less than {@code value} to the left branch. */
  public class OrthoRule extends Rule {
    private final int dim;
    private final double value;

    public OrthoRule(int dim, double value) {
      super(String.format("%s < %s", columnName(dim), value));
      this.dim = dim;
      this.value = value;
    }

    /** @return {@code true} iff {@code pt[dim] < value} (false for NaN coordinates). */
    @Override
    public boolean eval(double[] pt) {
      return pt[dim] < value;
    }
  }

  /**
   * A boolean test used to route points at a split.
   *
   * <p>{@code true} means the left branch.
   */
  public abstract class Rule {
    /** Human-readable description, also returned by {@link #toString()}. */
    public final String name;
    /** Score assigned when the rule was proposed; higher is better. */
    public double fitness;

    public Rule(String name) {
      this.name = name;
    }

    /**
     * Tests a point.
     *
     * @param pt the point to test
     * @return {@code true} to route {@code pt} left, {@code false} to route it right
     */
    public abstract boolean eval(double[] pt);

    @Override
    public String toString() {
      return name;
    }
  }

  /**
   * A tree node holding a set of points.
   *
   * <p>Construction immediately calls {@link #split()}, which may recursively create the
   * child nodes.
   */
  public class Node {
    /** Points contained in this node. */
    @javax.annotation.Nonnull
    public final double[][] points;
    /** Bounds of {@link #points}, as computed by {@link DensityTree#getBounds(double[][])}. */
    @javax.annotation.Nonnull
    public final Bounds bounds;
    private final int depth;
    @Nullable
    private Node left = null;
    @Nullable
    private Node right = null;
    @Nullable
    private Rule rule = null;

    /**
     * Creates a root node (depth 0) and splits it recursively.
     *
     * @param points the points to partition; must be non-empty
     */
    public Node(@javax.annotation.Nonnull double[][] points) {
      this(points, 0);
    }

    /**
     * Creates a node at the given depth and splits it recursively.
     *
     * @param points the points to partition; must be non-empty
     * @param depth  the depth of this node (0 for the root)
     */
    public Node(@javax.annotation.Nonnull double[][] points, int depth) {
      this.points = points;
      this.bounds = getBounds(points);
      this.depth = depth;
      split();
    }

    /** A node without a rule, or with a rule but missing children, is treated as a leaf. */
    private boolean isLeaf() {
      return null == rule || null == left || null == right;
    }

    /**
     * Encodes the leaf that {@code pt} reaches as an int.
     *
     * <p>The k-th routing decision, counting from this node, sets bit k when the point
     * satisfies the rule (left branch) and leaves it clear otherwise (right branch). Paths
     * longer than 31 decisions overflow.
     *
     * @param pt the point to route
     * @return the path code; 0 if this node is a leaf
     */
    public int predict(double[] pt) {
      if (isLeaf()) {
        return 0;
      }
      return rule.eval(pt) ? 1 + 2 * left.predict(pt) : 2 * right.predict(pt);
    }

    @Override
    public String toString() {
      return code();
    }

    /**
     * Renders this subtree as nested pseudo-code {@code if/else} blocks.
     *
     * <p>Each block is annotated with the node's point count, volume and bounds, plus the
     * rule's fitness.
     */
    public String code() {
      if (isLeaf()) {
        return "// " + dataInfo();
      }
      return String.format("// %s\nif(%s) { // Fitness %s\n  %s\n} else {\n  %s\n}",
          dataInfo(), rule, rule.fitness,
          left.code().replace("\n", "\n  "),
          right.code().replace("\n", "\n  "));
    }

    private CharSequence dataInfo() {
      return String.format("Count: %s Volume: %s Region: %s", points.length, bounds.getVolume(), bounds);
    }

    /**
     * Chooses the best orthogonal split across all dimensions and, if one exists, creates the
     * two child nodes.
     *
     * <p>Nothing happens when the node has at most {@code splitSizeThreshold} points, is at
     * depth {@code maxDepth} or deeper, has no qualifying candidate, or the best candidate
     * would leave one side empty.
     */
    public void split() {
      if (points.length <= splitSizeThreshold) return;
      if (maxDepth <= depth) return;
      Rule best = IntStream.range(0, points[0].length).boxed()
          .flatMap(this::split_ortho)
          .filter(r -> Double.isFinite(r.fitness))
          .max(Comparator.comparingDouble(r -> r.fitness))
          .orElse(null);
      if (null == best) return;
      double[][] leftPts = Arrays.stream(points).filter(best::eval).toArray(double[][]::new);
      double[][] rightPts = Arrays.stream(points).filter(pt -> !best.eval(pt)).toArray(double[][]::new);
      assert (leftPts.length + rightPts.length == points.length);
      // Commit the rule only together with both children, so a node never carries a rule
      // whose branches are missing.
      if (rightPts.length == 0 || leftPts.length == 0) return;
      this.rule = best;
      this.left = new Node(leftPts, depth + 1);
      this.right = new Node(rightPts, depth + 1);
    }

    /**
     * Proposes split rules of the form {@code pt[dim] < value} for one dimension.
     *
     * <p>Only points with a finite coordinate in {@code dim} are considered (n of them).
     * Cuts are placed between consecutive distinct sorted values, and each side must contain
     * more than {@code max((int) (n * minSplitFract), 1)} points.
     *
     * <p>Fitness is {@code -(nL * ln(vL / V) + nR * ln(vR / V)) / (n * ln 2)}, the
     * point-weighted mean, in bits, of {@code -log2(childVolume / nodeVolume)}. Candidates
     * with fitness not strictly above {@code minFitness} are dropped.
     *
     * @param dim the dimension to split on
     * @return a lazily evaluated stream of candidate rules (possibly empty)
     */
    public Stream<Rule> split_ortho(int dim) {
      double[][] sortedPoints = Arrays.stream(points)
          .filter(pt -> Double.isFinite(pt[dim]))
          .sorted(Comparator.comparingDouble(pt -> pt[dim]))
          .toArray(double[][]::new);
      final int n = sortedPoints.length;
      if (0 == n) return Stream.empty();
      final int minSize = (int) Math.max(n * minSplitFract, 1);
      // prefix[i] bounds sortedPoints[0..i]; suffix[i] bounds sortedPoints[i..n-1].
      final Bounds[] prefix = new Bounds[n];
      final Bounds[] suffix = new Bounds[n];
      prefix[0] = getBounds(new double[][]{sortedPoints[0]});
      suffix[n - 1] = getBounds(new double[][]{sortedPoints[n - 1]});
      for (int i = 1; i < n; i++) {
        prefix[i] = prefix[i - 1].union(sortedPoints[i]);
        int j = (n - 1) - i;
        suffix[j] = suffix[j + 1].union(sortedPoints[j]);
      }
      final double nodeVolume = Node.this.bounds.getVolume();
      final double normalizer = n * Math.log(2);
      return IntStream.range(1, n - 1)
          // Cut only between distinct values so "x < sortedPoints[i][dim]" separates exactly at i.
          .filter(i -> sortedPoints[i - 1][dim] < sortedPoints[i][dim])
          // i points go left and n - i go right; both must exceed minSize.
          .filter(i -> minSize < i && minSize < n - i)
          .mapToObj(i -> {
            Rule candidate = new OrthoRule(dim, sortedPoints[i][dim]);
            candidate.fitness = -(i * Math.log(prefix[i - 1].getVolume() / nodeVolume)
                + (n - i) * Math.log(suffix[i].getVolume() / nodeVolume)) / normalizer;
            return candidate;
          })
          .filter(candidate -> candidate.fitness > minFitness);
    }

    /** @return the split rule, or {@code null} if this node was not split */
    @Nullable
    public Rule getRule() {
      return rule;
    }

    @javax.annotation.Nonnull
    protected Node setRule(Rule rule) {
      this.rule = rule;
      return this;
    }

    /** @return the child for points failing the rule, or {@code null} for a leaf */
    @Nullable
    public Node getRight() {
      return right;
    }

    @javax.annotation.Nonnull
    protected Node setRight(Node right) {
      this.right = right;
      return this;
    }

    /** @return the child for points satisfying the rule, or {@code null} for a leaf */
    @Nullable
    public Node getLeft() {
      return left;
    }

    @javax.annotation.Nonnull
    protected Node setLeft(Node left) {
      this.left = left;
      return this;
    }

    /** @return the depth of this node (0 for the root) */
    public int getDepth() {
      return depth;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy