All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.tribuo.math.optimisers.util.ShrinkingVector Maven / Gradle / Ivy

There is a newer version: 4.3.1
Show newest version
/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.optimisers.util;

import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;

import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;

/**
 * A subclass of {@link DenseVector} which shrinks the value every time a new value is added.
 * 

* Be careful when modifying this or {@link DenseVector}. */ public class ShrinkingVector extends DenseVector implements ShrinkingTensor { private final double baseRate; private final boolean scaleShrinking; private final double lambdaSqrt; private final boolean reproject; private double squaredTwoNorm; private int iteration; private double multiplier; public ShrinkingVector(DenseVector v, double baseRate, boolean scaleShrinking) { super(v); this.baseRate = baseRate; this.scaleShrinking = scaleShrinking; this.lambdaSqrt = 0.0; this.reproject = false; this.iteration = 1; this.multiplier = 1.0; } public ShrinkingVector(DenseVector v, double baseRate, double lambda) { super(v); this.baseRate = baseRate; this.scaleShrinking = true; this.lambdaSqrt = Math.sqrt(lambda); this.reproject = true; this.squaredTwoNorm = 0.0; this.iteration = 1; this.multiplier = 1.0; } private ShrinkingVector(double[] values, double baseRate, boolean scaleShrinking, double lambdaSqrt, boolean reproject, double squaredTwoNorm, int iteration, double multiplier) { super(values); this.baseRate = baseRate; this.scaleShrinking = scaleShrinking; this.lambdaSqrt = lambdaSqrt; this.reproject = reproject; this.squaredTwoNorm = squaredTwoNorm; this.iteration = iteration; this.multiplier = multiplier; } @Override public DenseVector convertToDense() { return DenseVector.createDenseVector(toArray()); } @Override public ShrinkingVector copy() { return new ShrinkingVector(Arrays.copyOf(elements,elements.length),baseRate,scaleShrinking,lambdaSqrt,reproject,squaredTwoNorm,iteration,multiplier); } @Override public double[] toArray() { double[] newValues = new double[elements.length]; for (int i = 0; i < newValues.length; i++) { newValues[i] = get(i); } return newValues; } @Override public double get(int index) { return elements[index] * multiplier; } @Override public double sum() { double sum = 0.0; for (int i = 0; i < elements.length; i++) { sum += get(i); } return sum; } @Override public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) { double shrinkage = scaleShrinking ? 1.0 - (baseRate / iteration) : 1.0 - baseRate; scaleInPlace(shrinkage); SGDVector otherVec = (SGDVector) other; for (VectorTuple tuple : otherVec) { double update = f.applyAsDouble(tuple.value); double oldValue = elements[tuple.index] * multiplier; double newValue = oldValue + update; squaredTwoNorm -= oldValue * oldValue; squaredTwoNorm += newValue * newValue; elements[tuple.index] = newValue / multiplier; } if (reproject) { double projectionNormaliser = (1.0 / lambdaSqrt) / twoNorm(); if (projectionNormaliser < 1.0) { scaleInPlace(projectionNormaliser); } } iteration++; } @Override public int indexOfMax() { int index = 0; double value = Double.NEGATIVE_INFINITY; for (int i = 0; i < elements.length; i++) { double tmp = get(i); if (tmp > value) { index = i; value = tmp; } } return index; } @Override public double dot(SGDVector other) { double score = 0.0; for (VectorTuple tuple : other) { score += get(tuple.index) * tuple.value; } return score; } @Override public void scaleInPlace(double value) { multiplier *= value; if (Math.abs(multiplier) < tolerance) { reifyMultiplier(); } } private void reifyMultiplier() { for (int i = 0; i < elements.length; i++) { elements[i] *= multiplier; } multiplier = 1.0; } @Override public double twoNorm() { return Math.sqrt(squaredTwoNorm); } @Override public double maxValue() { return multiplier * super.maxValue(); } @Override public double minValue() { return multiplier * super.minValue(); } @Override public VectorIterator iterator() { return new ShrinkingVectorIterator(this); } private static class ShrinkingVectorIterator implements VectorIterator { private final ShrinkingVector vector; private final VectorTuple tuple; private int index; public ShrinkingVectorIterator(ShrinkingVector vector) { this.vector = vector; this.tuple = new VectorTuple(); this.index = 0; } @Override public boolean hasNext() { return index < vector.size(); } @Override public VectorTuple next() { tuple.index = index; tuple.value = vector.get(index); index++; return tuple; } @Override public VectorTuple getReference() { return tuple; } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy