com.simiacryptus.mindseye.lang.PointSample Maven / Gradle / Ivy
/*
* Copyright (c) 2019 by Andrew Charneski.
*
* The author licenses this file to you under the
* Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance
* with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package com.simiacryptus.mindseye.lang;
import com.simiacryptus.lang.ref.ReferenceCountingBase;
import javax.annotation.Nonnull;
import java.util.UUID;
/**
 * A reference-counted sample pairing a {@link DeltaSet} with the {@link StateSet} of weights it
 * refers to, together with an accumulated {@link #sum}, the number of samples folded into it
 * ({@link #count}) and a mutable {@link #rate}.
 *
 * <p>Ownership contract as exercised by this class: the constructor builds its own
 * {@code DeltaSet}/{@code StateSet} instances from the arguments, so a caller keeps (and must
 * release) any reference it passes in. The two sets owned by this instance are released in
 * {@link #_free()}.
 */
public final class PointSample extends ReferenceCountingBase {
  /** Number of samples aggregated into {@link #sum}. */
  public final int count;
  /** Delta set owned by this sample; keyed consistently with {@link #weights}. */
  @Nonnull
  public final DeltaSet delta;
  /** Accumulated value; {@link #getMean()} divides it by {@link #count}. */
  public final double sum;
  /** Weight state set owned by this sample. */
  @Nonnull
  public final StateSet weights;
  /** Rate associated with this sample; mutable via {@link #setRate(double)}. */
  public double rate;

  /**
   * Creates a sample from the given sets.
   *
   * @param delta   delta set; a new {@code DeltaSet} is constructed from it and the caller keeps
   *                ownership of the argument. Asserted to have the same size as {@code weights}
   *                and only keys present in {@code weights}.
   * @param weights weight set; a new {@code StateSet} is constructed from it and the caller
   *                keeps ownership of the argument.
   * @param sum     accumulated value
   * @param rate    initial rate
   * @param count   number of samples aggregated into {@code sum}
   * @throws RuntimeException if construction fails; non-runtime throwables (including
   *                          assertion errors) are wrapped in a {@code RuntimeException}
   */
  public PointSample(@Nonnull final DeltaSet delta, @Nonnull final StateSet weights, final double sum, final double rate, final int count) {
    try {
      assert delta.getMap().size() == weights.getMap().size();
      this.delta = new DeltaSet<>(delta);
      this.weights = new StateSet<>(weights);
      assert delta.getMap().keySet().stream().allMatch(x -> weights.getMap().containsKey(x));
      this.sum = sum;
      this.count = count;
      setRate(rate);
    } catch (RuntimeException e) {
      // Release whatever was allocated so far; _free() tolerates still-null fields.
      freeRef();
      throw e;
    } catch (Throwable e) {
      freeRef();
      throw new RuntimeException(e);
    }
  }

  /**
   * Combines two samples into a new one: deltas are added, weight sets are unioned, sums and
   * counts are added, and {@code left}'s rate is kept.
   *
   * <p>Neither argument is released or modified by this method (only {@code DeltaSet.add} and
   * {@code StateSet.union} are called on them).
   *
   * @param left  first sample; asserted to share its rate with {@code right}
   * @param right second sample
   * @return a new sample owned by the caller
   */
  public static PointSample add(@Nonnull final PointSample left, @Nonnull final PointSample right) {
    assert left.delta.getMap().size() == left.weights.getMap().size();
    assert right.delta.getMap().size() == right.weights.getMap().size();
    assert left.rate == right.rate;
    DeltaSet delta = left.delta.add(right.delta);
    try {
      StateSet stateSet = StateSet.union(left.weights, right.weights);
      try {
        return new PointSample(delta,
            stateSet,
            left.sum + right.sum,
            left.rate,
            left.count + right.count);
      } finally {
        // The constructor builds its own sets, so these temporaries are always ours to release,
        // including when construction throws.
        stateSet.freeRef();
      }
    } finally {
      delta.freeRef();
    }
  }

  /**
   * Instance form of {@link #add(PointSample, PointSample)} with {@code this} as the left side.
   *
   * @param right sample to combine with this one
   * @return a new sample owned by the caller
   */
  public PointSample add(@Nonnull final PointSample right) {
    return PointSample.add(this, right);
  }

  /**
   * Accumulates {@code right.delta} into this sample's delta via {@code DeltaSet.addInPlace}
   * (presumably mutating {@link #delta}), then returns a new sample built from the result, the
   * union of both weight sets, and the summed {@code sum}/{@code count}.
   *
   * @param right sample whose delta is accumulated; asserted to share this sample's rate
   * @return a new sample owned by the caller
   */
  public PointSample addInPlace(@Nonnull final PointSample right) {
    assert delta.getMap().size() == weights.getMap().size();
    assert right.delta.getMap().size() == right.weights.getMap().size();
    assert rate == right.rate;
    // NOTE(review): the DeltaSet returned by addInPlace is not released here, as before; its
    // ownership (likely this.delta itself) cannot be confirmed from this file.
    DeltaSet accumulated = delta.addInPlace(right.delta);
    StateSet union = StateSet.union(weights, right.weights);
    try {
      return new PointSample(accumulated,
          union,
          sum + right.sum,
          rate,
          count + right.count);
    } finally {
      // The union is a temporary: the constructor builds its own StateSet from it.
      union.freeRef();
    }
  }

  /**
   * Returns a new sample with a copy of this sample's delta and the same weights, sum, rate and
   * count.
   *
   * @return a new sample owned by the caller
   */
  public PointSample copyDelta() {
    DeltaSet deltaCopy = delta.copy();
    try {
      return new PointSample(deltaCopy, weights, sum, rate, count);
    } finally {
      // Release the temporary copy, matching copyFull(); the constructor builds its own set.
      deltaCopy.freeRef();
    }
  }

  /**
   * Returns a new sample with copies of both this sample's delta and its weights.
   *
   * @return a new sample owned by the caller
   */
  @Nonnull
  public PointSample copyFull() {
    @Nonnull DeltaSet deltaCopy = delta.copy();
    try {
      @Nonnull StateSet weightsCopy = weights.copy();
      try {
        return new PointSample(deltaCopy, weightsCopy, sum, rate, count);
      } finally {
        weightsCopy.freeRef();
      }
    } finally {
      deltaCopy.freeRef();
    }
  }

  /**
   * Returns {@code sum / count}. Because {@code sum} is a double, a {@code count} of 0 yields
   * NaN or an infinity rather than throwing.
   *
   * @return the mean value per sample
   */
  public double getMean() {
    return sum / count;
  }

  /**
   * @return the current rate
   */
  public double getRate() {
    return rate;
  }

  /**
   * Sets the rate in place.
   *
   * @param rate new rate
   * @return {@code this} (no additional reference is taken)
   */
  @Nonnull
  public PointSample setRate(final double rate) {
    this.rate = rate;
    return this;
  }

  /**
   * Returns a sample representing a single averaged sample: the delta scaled by
   * {@code 1.0 / count}, {@code sum / count}, and a count of 1.
   *
   * <p>If {@code count} is already 1, this instance is returned with an extra reference taken,
   * so the caller must release the result in either case.
   *
   * @return a sample owned by the caller
   */
  @Nonnull
  public PointSample normalize() {
    if (count == 1) {
      this.addRef();
      return this;
    }
    @Nonnull DeltaSet scale = delta.scale(1.0 / count);
    try {
      return new PointSample(scale, weights, sum / count, rate, 1);
    } finally {
      scale.freeRef();
    }
  }

  /**
   * Calls {@code restore()} on every element of {@link #weights}.
   *
   * @return {@code this} (no additional reference is taken)
   */
  @Nonnull
  public PointSample restore() {
    weights.stream().forEach(d -> d.restore());
    return this;
  }

  /**
   * Calls {@code backup()} on every element of {@link #weights}.
   *
   * @return {@code this} (no additional reference is taken)
   */
  @Nonnull
  public PointSample backup() {
    weights.stream().forEach(d -> d.backup());
    return this;
  }

  /**
   * @return a short description containing the mean, e.g. {@code PointSample{avg=0.5}}
   */
  @Override
  public String toString() {
    // Local, single-threaded builder: StringBuilder avoids StringBuffer's synchronization.
    @Nonnull final StringBuilder sb = new StringBuilder("PointSample{");
    sb.append("avg=").append(getMean());
    sb.append('}');
    return sb.toString();
  }

  /** Releases the delta and weight sets owned by this sample; tolerates null fields. */
  @Override
  protected void _free() {
    if (null != this.weights) this.weights.freeRef();
    if (null != this.delta) this.delta.freeRef();
  }

  /** Covariant override of {@code addRef} returning this type. */
  @Override
  public PointSample addRef() {
    return (PointSample) super.addRef();
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy