
com.groupbyinc.flux.next.common.tdunning.math.stats.ArrayDigest

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.tdunning.math.stats;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/**
 * Array based implementation of a TDigest.
 *
 * This implementation is essentially a one-level b-tree in which nodes are collected into
 * pages, typically with 32 values per page. Commonly, an ArrayDigest contains 500-3000
 * centroids. With 32 values per page, that is about 30 pages, which seems to give a nice
 * balance for speed. Page sizes from 4 to 100 are plausible, however.
 */
public class ArrayDigest extends AbstractTDigest {
    private final int pageSize;

    private List<Page> data = new ArrayList<Page>();
    private long totalWeight = 0;
    private int centroidCount = 0;
    private double compression = 100;

    public ArrayDigest(int pageSize, double compression) {
        if (pageSize > 3) {
            this.pageSize = pageSize;
            this.compression = compression;
        } else {
            throw new IllegalArgumentException("Must have page size of 4 or more");
        }
    }

    @Override
    public void add(double x, int w) {
        checkValue(x);
        Index start = floor(x);
        if (start == null) {
            start = ceiling(x);
        }

        if (start == null) {
            addRaw(x, w);
        } else {
            Iterable<Index> neighbors = inclusiveTail(start);
            double minDistance = Double.MAX_VALUE;
            int lastNeighbor = 0;
            int i = 0;
            for (Index neighbor : neighbors) {
                double z = Math.abs(mean(neighbor) - x);
                if (z <= minDistance) {
                    minDistance = z;
                    lastNeighbor = i;
                } else {
                    // as soon as z exceeds the minimum, we have passed the nearest neighbor and can quit
                    break;
                }
                i++;
            }

            Index closest = null;
            long sum = headSum(start);
            i = 0;
            double n = 0;
            for (Index neighbor : neighbors) {
                if (i > lastNeighbor) {
                    break;
                }
                double z = Math.abs(mean(neighbor) - x);
                double q = (sum + count(neighbor) / 2.0) / totalWeight;
                double k = 4 * totalWeight * q * (1 - q) / compression;

                // this slightly clever selection method improves accuracy with lots of repeated points
                if (z == minDistance && count(neighbor) + w <= k) {
                    n++;
                    if (gen.nextDouble() < 1 / n) {
                        closest = neighbor;
                    }
                }
                sum += count(neighbor);
                i++;
            }

            if (closest == null) {
                addRaw(x, w);
            } else {
                if (n == 1) {
                    // if the nearest point was unique, centroid ordering cannot change
                    Page p = data.get(closest.page);
                    p.counts[closest.subPage] += w;
                    p.totalCount += w;
                    p.centroids[closest.subPage] += (x - p.centroids[closest.subPage]) / p.counts[closest.subPage];
                    if (p.history != null && p.history.get(closest.subPage) != null) {
                        p.history.get(closest.subPage).add(x);
                    }
                    totalWeight += w;
                } else {
                    // if the nearest point was not unique, then we may not be modifying the first copy
                    // which means that ordering can change
                    int weight = count(closest) + w;
                    double center = mean(closest);
                    center = center + (x - center) / weight;

                    if (mean(increment(closest, -1)) <= center && mean(increment(closest, 1)) >= center) {
                        // if order doesn't change, we can short-cut the process
                        Page p = data.get(closest.page);
                        p.counts[closest.subPage] = weight;
                        p.centroids[closest.subPage] = center;
                        p.totalCount += w;
                        totalWeight += w;
                        if (p.history != null && p.history.get(closest.subPage) != null) {
                            p.history.get(closest.subPage).add(x);
                        }
                    } else {
                        delete(closest);
                        List<Double> history = history(closest);
                        if (history != null) {
                            history.add(x);
                        }
                        addRaw(center, weight, history);
                    }
                }
            }

            if (centroidCount > 20 * compression) {
                // something such as sequential ordering of data points
                // has caused a pathological expansion of our summary.
                // To fight this, we simply replay the current centroids
                // in random order.

                // this causes us to forget the diagnostic recording of data points
                compress();
            }
        }
    }

    public long headSum(Index limit) {
        long r = 0;

        for (int i = 0; limit != null && i < limit.page; i++) {
            r += data.get(i).totalCount;
        }

        if (limit != null && limit.page < data.size()) {
            for (int j = 0; j < limit.subPage; j++) {
                r += data.get(limit.page).counts[j];
            }
        }

        return r;
    }

    /**
     * Returns the number of centroids strictly before the limit.
     */
    private int headCount(Index limit) {
        int r = 0;

        for (int i = 0; i < limit.page; i++) {
            r += data.get(i).active;
        }

        if (limit.page < data.size()) {
            for (int j = 0; j < limit.subPage; j++) {
                r++;
            }
        }

        return r;
    }

    public double mean(Index index) {
        return data.get(index.page).centroids[index.subPage];
    }

    public int count(Index index) {
        return data.get(index.page).counts[index.subPage];
    }

    @Override
    public void compress() {
        ArrayDigest reduced = new ArrayDigest(pageSize, compression);
        if (recordAllData) {
            reduced.recordAllData();
        }

        List<Index> tmp = new ArrayList<Index>();
        Iterator<Index> ix = this.iterator(0, 0);
        while (ix.hasNext()) {
            tmp.add(ix.next());
        }
        Collections.shuffle(tmp, gen);
        for (Index index : tmp) {
            reduced.add(mean(index), count(index));
        }

        data = reduced.data;
        centroidCount = reduced.centroidCount;
    }

    @Override
    public void compress(GroupTree other) {
        throw new UnsupportedOperationException("Default operation");
    }

    @Override
    public long size() {
        return totalWeight;
    }

    @Override
    public double cdf(double x) {
        if (size() == 0) {
            return Double.NaN;
        } else if (size() == 1) {
            return x < data.get(0).centroids[0] ? 0 : 1;
        } else {
            double r = 0;

            // we scan a across the centroids
            Iterator<Index> it = iterator(0, 0);
            Index a = it.next();

            // b is the look-ahead to the next centroid
            Index b = it.next();

            // initially, we set left width equal to right width
            double left = (b.mean() - a.mean()) / 2;
            double right = left;

            // scan to next to last element
            while (it.hasNext()) {
                if (x < a.mean() + right) {
                    return (r + a.count() * AbstractTDigest.interpolate(x, a.mean() - left, a.mean() + right)) / totalWeight;
                }

                r += a.count();

                a = b;
                b = it.next();

                left = right;
                right = (b.mean() - a.mean()) / 2;
            }

            // for the last element, assume right width is same as left
            left = right;
            a = b;
            if (x < a.mean() + right) {
                return (r + a.count() * AbstractTDigest.interpolate(x, a.mean() - left, a.mean() + right)) / totalWeight;
            } else {
                return 1;
            }
        }
    }

    @Override
    public double quantile(double q) {
        if (q < 0 || q > 1) {
            throw new IllegalArgumentException("q should be in [0,1], got " + q);
        }

        if (centroidCount() == 0) {
            return Double.NaN;
        } else if (centroidCount() == 1) {
            return data.get(0).centroids[0];
        }

        // if values were stored in a sorted array, index would be the offset we are interested in
        final double index = q * (size() - 1);

        double previousMean = Double.NaN, previousIndex = 0;
        long total = 0;

        // Jump over pages until we reach the page containing the quantile we are interested in
        int firstPage = 0;
        while (firstPage < data.size() && total + data.get(firstPage).totalCount < index) {
            total += data.get(firstPage++).totalCount;
        }

        Iterator<Index> it;
        if (firstPage == 0) {
            // start from the beginning
            it = iterator(0, 0);
        } else {
            final int previousPageIndex = firstPage - 1;
            final Page previousPage = data.get(previousPageIndex);
            assert previousPage.active > 0;
            final int lastSubPage = previousPage.active - 1;
            previousMean = previousPage.centroids[lastSubPage];
            previousIndex = total - (previousPage.counts[lastSubPage] + 1.0) / 2;
            it = iterator(firstPage, 0);
        }

        Index next;
        while (true) {
            next = it.next();
            final double nextIndex = total + (next.count() - 1.0) / 2;
            if (nextIndex >= index) {
                if (Double.isNaN(previousMean)) {
                    assert total == 0;
                    // special case 1: the index we are interested in is before the 1st centroid
                    if (nextIndex == previousIndex) {
                        return next.mean();
                    }
                    // assume values grow linearly between index previousIndex=0 and nextIndex2
                    Index next2 = it.next();
                    final double nextIndex2 = total + next.count() + (next2.count() - 1.0) / 2;
                    previousMean = (nextIndex2 * next.mean() - nextIndex * next2.mean()) / (nextIndex2 - nextIndex);
                }
                // common case: we found two centroids previous and next so that the desired quantile is
                // after 'previous' but before 'next'
                return quantile(previousIndex, index, nextIndex, previousMean, next.mean());
            } else if (!it.hasNext()) {
                // special case 2: the index we are interested in is beyond the last centroid
                // again, assume values grow linearly between index previousIndex and (count - 1)
                // which is the highest possible index
                final double nextIndex2 = size() - 1;
                final double nextMean2 = (next.mean() * (nextIndex2 - previousIndex) - previousMean * (nextIndex2 - nextIndex)) / (nextIndex - previousIndex);
                return quantile(nextIndex, index, nextIndex2, next.mean(), nextMean2);
            }
            total += next.count();
            previousMean = next.mean();
            previousIndex = nextIndex;
        }
    }

    @Override
    public int centroidCount() {
        return centroidCount;
    }

    @Override
    public Iterable<? extends Centroid> centroids() {
        List<Centroid> r = new ArrayList<Centroid>();
        Iterator<Index> ix = iterator(0, 0);
        while (ix.hasNext()) {
            Index index = ix.next();
            Page current = data.get(index.page);
            Centroid centroid = new Centroid(current.centroids[index.subPage], current.counts[index.subPage]);
            if (current.history != null) {
                for (double x : current.history.get(index.subPage)) {
                    centroid.insertData(x);
                }
            }
            r.add(centroid);
        }
        return r;
    }

    public Iterator<Index> allAfter(double x) {
        if (data.size() == 0) {
            return iterator(0, 0);
        } else {
            for (int i = 1; i < data.size(); i++) {
                if (data.get(i).centroids[0] >= x) {
                    Page previous = data.get(i - 1);
                    for (int j = 0; j < previous.active; j++) {
                        if (previous.centroids[j] > x) {
                            return iterator(i - 1, j);
                        }
                    }
                    return iterator(i, 0);
                }
            }
            Page last = data.get(data.size() - 1);
            for (int j = 0; j < last.active; j++) {
                if (last.centroids[j] > x) {
                    return iterator(data.size() - 1, j);
                }
            }
            return iterator(data.size(), 0);
        }
    }

    /**
     * Returns a cursor pointing to the first element <= x. Exposed only for testing.
     *
     * @param x The value used to find the cursor.
     * @return The cursor.
     */
    public Index floor(double x) {
        Iterator<Index> rx = allBefore(x);
        if (!rx.hasNext()) {
            return null;
        }

        Index r = rx.next();
        Index z = r;
        while (rx.hasNext() && mean(z) == x) {
            r = z;
            z = rx.next();
        }
        return r;
    }

    public Index ceiling(double x) {
        Iterator<Index> r = allAfter(x);
        return r.hasNext() ? r.next() : null;
    }

    /**
     * Returns an iterator which will give each element <= x in non-increasing order.
     *
     * @param x The upper bound of all returned elements
     * @return An iterator that returns elements in non-increasing order.
     */
    public Iterator<Index> allBefore(double x) {
        if (data.size() == 0) {
            return iterator(0, 0);
        } else {
            for (int i = 1; i < data.size(); i++) {
                if (data.get(i).centroids[0] > x) {
                    Page previous = data.get(i - 1);
                    for (int j = 0; j < previous.active; j++) {
                        if (previous.centroids[j] > x) {
                            return reverse(i - 1, j - 1);
                        }
                    }
                    return reverse(i, -1);
                }
            }
            Page last = data.get(data.size() - 1);
            for (int j = 0; j < last.active; j++) {
                if (last.centroids[j] > x) {
                    return reverse(data.size() - 1, j - 1);
                }
            }
            return reverse(data.size(), -1);
        }
    }

    public Index increment(Index x, int delta) {
        int i = x.page;
        int j = x.subPage + delta;

        while (i < data.size() && j >= data.get(i).active) {
            j -= data.get(i).active;
            i++;
        }

        while (i > 0 && j < 0) {
            i--;
            j += data.get(i).active;
        }
        return new Index(i, j);
    }

    @Override
    public double compression() {
        return compression;
    }

    /**
     * Returns an upper bound on the number of bytes that will be required to represent this histogram.
     */
    @Override
    public int byteSize() {
        return 4 + 8 + 8 + centroidCount * 12;
    }

    /**
     * Returns an upper bound on the number of bytes that will be required to represent this histogram in
     * the tighter representation.
     */
    @Override
    public int smallByteSize() {
        int bound = byteSize();
        ByteBuffer buf = ByteBuffer.allocate(bound);
        asSmallBytes(buf);
        return buf.position();
    }

    /**
     * Outputs a histogram as bytes using a particularly cheesy encoding.
     */
    @Override
    public void asBytes(ByteBuffer buf) {
        buf.putInt(VERBOSE_ARRAY_DIGEST);
        buf.putDouble(compression());
        buf.putInt(pageSize);
        buf.putInt(centroidCount);
        for (Page page : data) {
            for (int i = 0; i < page.active; i++) {
                buf.putDouble(page.centroids[i]);
            }
        }

        for (Page page : data) {
            for (int i = 0; i < page.active; i++) {
                buf.putInt(page.counts[i]);
            }
        }
    }

    @Override
    public void asSmallBytes(ByteBuffer buf) {
        buf.putInt(SMALL_ARRAY_DIGEST);
        buf.putDouble(compression());
        buf.putInt(pageSize);
        buf.putInt(centroidCount);

        double x = 0;
        for (Page page : data) {
            for (int i = 0; i < page.active; i++) {
                double mean = page.centroids[i];
                double delta = mean - x;
                x = mean;
                buf.putFloat((float) delta);
            }
        }

        for (Page page : data) {
            for (int i = 0; i < page.active; i++) {
                int n = page.counts[i];
                encode(buf, n);
            }
        }
    }

    /**
     * Reads a histogram from a byte buffer.
     *
     * @return The new histogram structure
     */
    public static ArrayDigest fromBytes(ByteBuffer buf) {
        int encoding = buf.getInt();
        if (encoding == VERBOSE_ENCODING || encoding == VERBOSE_ARRAY_DIGEST) {
            double compression = buf.getDouble();
            int pageSize = 32;
            if (encoding == VERBOSE_ARRAY_DIGEST) {
                pageSize = buf.getInt();
            }
            ArrayDigest r = new ArrayDigest(pageSize, compression);
            int n = buf.getInt();
            double[] means = new double[n];
            for (int i = 0; i < n; i++) {
                means[i] = buf.getDouble();
            }
            for (int i = 0; i < n; i++) {
                r.add(means[i], buf.getInt());
            }
            return r;
        } else if (encoding == SMALL_ENCODING || encoding == SMALL_ARRAY_DIGEST) {
            double compression = buf.getDouble();
            int pageSize = 32;
            if (encoding == SMALL_ARRAY_DIGEST) {
                pageSize = buf.getInt();
            }
            ArrayDigest r = new ArrayDigest(pageSize, compression);
            int n = buf.getInt();
            double[] means = new double[n];
            double x = 0;
            for (int i = 0; i < n; i++) {
                double delta = buf.getFloat();
                x += delta;
                means[i] = x;
            }

            for (int i = 0; i < n; i++) {
                int z = decode(buf);
                r.add(means[i], z);
            }
            return r;
        } else {
            throw new IllegalStateException("Invalid format for serialized histogram");
        }
    }

    private List<Double> history(Index index) {
        List<List<Double>> h = data.get(index.page).history;
        return h == null ? null : h.get(index.subPage);
    }

    private void delete(Index index) {
        // don't want to delete empty pages here because other indexes would be screwed up.
        // this should almost never happen anyway since deletes only cause small ordering
        // changes
        totalWeight -= count(index);
        centroidCount--;
        data.get(index.page).delete(index.subPage);
    }

    private Iterable<Index> inclusiveTail(final Index start) {
        return new Iterable<Index>() {
            @Override
            public Iterator<Index> iterator() {
                return ArrayDigest.this.iterator(start.page, start.subPage);
            }
        };
    }

    void addRaw(double x, int w) {
        List<Double> tmp = new ArrayList<Double>();
        tmp.add(x);
        addRaw(x, w, recordAllData ? tmp : null);
    }

    void addRaw(double x, int w, List<Double> history) {
        if (centroidCount == 0) {
            Page page = new Page(pageSize, recordAllData);
            page.add(x, w, history);
            totalWeight += w;
            centroidCount++;
            data.add(page);
        } else {
            for (int i = 1; i < data.size(); i++) {
                if (data.get(i).centroids[0] > x) {
                    Page newPage = data.get(i - 1).add(x, w, history);
                    totalWeight += w;
                    centroidCount++;
                    if (newPage != null) {
                        data.add(i, newPage);
                    }
                    return;
                }
            }
            Page newPage = data.get(data.size() - 1).add(x, w, history);
            totalWeight += w;
            centroidCount++;
            if (newPage != null) {
                data.add(data.size(), newPage);
            }
        }
    }

    @Override
    void add(double x, int w, Centroid base) {
        addRaw(x, w, base.data());
    }

    private Iterator<Index> iterator(final int startPage, final int startSubPage) {
        return new Iterator<Index>() {
            int page = startPage;
            int subPage = startSubPage;
            Index end = new Index(-1, -1);
            Index next = null;

            @Override
            public boolean hasNext() {
                if (next == null) {
                    next = computeNext();
                }
                return next != end;
            }

            @Override
            public Index next() {
                if (hasNext()) {
                    Index r = next;
                    next = null;
                    return r;
                } else {
                    throw new NoSuchElementException("Can't iterate past end of data");
                }
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("Default operation");
            }

            protected Index computeNext() {
                if (page >= data.size()) {
                    return end;
                } else {
                    Page current = data.get(page);
                    if (subPage >= current.active) {
                        subPage = 0;
                        page++;
                        return computeNext();
                    } else {
                        Index r = new Index(page, subPage);
                        subPage++;
                        return r;
                    }
                }
            }
        };
    }

    private Iterator<Index> reverse(final int startPage, final int startSubPage) {
        return new Iterator<Index>() {
            int page = startPage;
            int subPage = startSubPage;
            Index end = new Index(-1, -1);
            Index next = null;

            @Override
            public boolean hasNext() {
                if (next == null) {
                    next = computeNext();
                }
                return next != end;
            }

            @Override
            public Index next() {
                if (hasNext()) {
                    Index r = next;
                    next = null;
                    return r;
                } else {
                    throw new NoSuchElementException("Can't reverse iterate before beginning of data");
                }
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("Default operation");
            }

            protected Index computeNext() {
                if (page < 0) {
                    return end;
                } else {
                    if (subPage < 0) {
                        page--;
                        if (page >= 0) {
                            subPage = data.get(page).active - 1;
                        }
                        return computeNext();
                    } else {
                        Index r = new Index(page, subPage);
                        subPage--;
                        return r;
                    }
                }
            }
        };
    }

    public final static int VERBOSE_ENCODING = 1;
    public final static int SMALL_ENCODING = 2;
    public final static int VERBOSE_ARRAY_DIGEST = 3;
    public final static int SMALL_ARRAY_DIGEST = 4;

    class Index {
        final int page, subPage;

        private Index(int page, int subPage) {
            this.page = page;
            this.subPage = subPage;
        }

        double mean() {
            return data.get(page).centroids[subPage];
        }

        int count() {
            return data.get(page).counts[subPage];
        }
    }

    private static class Page {
        private final boolean recordAllData;
        private final int pageSize;

        long totalCount;
        int active;
        double[] centroids;
        int[] counts;
        List<List<Double>> history;

        private Page(int pageSize, boolean recordAllData) {
            this.pageSize = pageSize;
            this.recordAllData = recordAllData;
            centroids = new double[this.pageSize];
            counts = new int[this.pageSize];
            history = this.recordAllData ? new ArrayList<List<Double>>() : null;
        }

        public Page add(double x, int w, List<Double> history) {
            for (int i = 0; i < active; i++) {
                if (centroids[i] >= x) {
                    // insert at i
                    if (active >= pageSize) {
                        // split page
                        Page newPage = split();
                        if (i < pageSize / 2) {
                            addAt(i, x, w, history);
                        } else {
                            newPage.addAt(i - pageSize / 2, x, w, history);
                        }
                        return newPage;
                    } else {
                        addAt(i, x, w, history);
                        return null;
                    }
                }
            }

            // insert at end
            if (active >= pageSize) {
                // split page
                Page newPage = split();
                newPage.addAt(newPage.active, x, w, history);
                return newPage;
            } else {
                addAt(active, x, w, history);
                return null;
            }
        }

        private void addAt(int i, double x, int w, List<Double> history) {
            if (i < active) {
                // shift data to make room
                System.arraycopy(centroids, i, centroids, i + 1, active - i);
                System.arraycopy(counts, i, counts, i + 1, active - i);
                if (this.history != null) {
                    this.history.add(i, history);
                }
                centroids[i] = x;
                counts[i] = w;
            } else {
                centroids[active] = x;
                counts[active] = w;
                if (this.history != null) {
                    this.history.add(history);
                }
            }
            active++;
            totalCount += w;
        }

        private Page split() {
            assert active == pageSize;
            final int half = pageSize / 2;
            Page newPage = new Page(pageSize, recordAllData);
            System.arraycopy(centroids, half, newPage.centroids, 0, pageSize - half);
            System.arraycopy(counts, half, newPage.counts, 0, pageSize - half);
            if (history != null) {
                newPage.history = new ArrayList<List<Double>>();
                newPage.history.addAll(history.subList(half, pageSize));

                List<List<Double>> tmp = new ArrayList<List<Double>>();
                tmp.addAll(history.subList(0, half));
                history = tmp;
            }
            active = half;
            newPage.active = pageSize - half;
            newPage.totalCount = totalCount;
            totalCount = 0;
            for (int i = 0; i < half; i++) {
                totalCount += counts[i];
                newPage.totalCount -= counts[i];
            }
            return newPage;
        }

        public void delete(int i) {
            int w = counts[i];
            if (i != active - 1) {
                System.arraycopy(centroids, i + 1, centroids, i, active - i - 1);
                System.arraycopy(counts, i + 1, counts, i, active - i - 1);
                if (history != null) {
                    history.remove(i);
                }
            }
            active--;
            totalCount -= w;
        }
    }
}
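For orientation, here is a minimal usage sketch of this class. It is not part of the artifact source above: the class name ArrayDigestExample is hypothetical, and the page size of 32 and compression of 100 are illustrative assumptions (the constructor only requires a page size of 4 or more, and 100 matches the field's default).

// Minimal usage sketch (assumed example, not part of the library):
// build a digest from random samples, then query quantiles and the CDF.
import com.tdunning.math.stats.ArrayDigest;

import java.util.Random;

public class ArrayDigestExample {
    public static void main(String[] args) {
        // Illustrative parameters: 32 values per page, compression of 100.
        ArrayDigest digest = new ArrayDigest(32, 100);

        // Feed in 100,000 Gaussian samples, each with weight 1.
        Random random = new Random(42);
        for (int i = 0; i < 100000; i++) {
            digest.add(random.nextGaussian(), 1);
        }

        // quantile(q) estimates the value at quantile q; cdf(x) estimates P(X <= x).
        System.out.println("estimated median: " + digest.quantile(0.5));
        System.out.println("estimated P(X <= 0): " + digest.cdf(0.0));
    }
}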




