All downloads are free. Search and download functionality uses the official Maven repository.

org.tribuo.math.optimisers.util.ShrinkingMatrix Maven / Gradle / Ivy

There is a newer version: 4.3.1
Show newest version
/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.math.optimisers.util;

import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.MatrixIterator;
import org.tribuo.math.la.MatrixTuple;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;

import java.util.NoSuchElementException;
import java.util.function.DoubleUnaryOperator;

/**
 * A subclass of {@link DenseMatrix} which shrinks the value every time a new value is added.
 * 

* Be careful when modifying this or {@link DenseMatrix}. */ public class ShrinkingMatrix extends DenseMatrix implements ShrinkingTensor { private final double baseRate; private final double lambdaSqrt; private final boolean scaleShrinking; private final boolean reproject; private double squaredTwoNorm; private int iteration; private double multiplier; public ShrinkingMatrix(DenseMatrix v, double baseRate, boolean scaleShrinking) { super(v); this.baseRate = baseRate; this.scaleShrinking = scaleShrinking; this.lambdaSqrt = 0.0; this.reproject = false; this.squaredTwoNorm = 0.0; this.iteration = 1; this.multiplier = 1.0; } public ShrinkingMatrix(DenseMatrix v, double baseRate, double lambda) { super(v); this.baseRate = baseRate; this.scaleShrinking = true; this.lambdaSqrt = Math.sqrt(lambda); this.reproject = true; this.squaredTwoNorm = 0.0; this.iteration = 1; this.multiplier = 1.0; } @Override public DenseMatrix convertToDense() { return new DenseMatrix(this); } @Override public DenseVector leftMultiply(SGDVector input) { if (input.size() == dim2) { double[] output = new double[dim1]; for (VectorTuple tuple : input) { for (int i = 0; i < output.length; i++) { output[i] += get(i, tuple.index) * tuple.value; } } return DenseVector.createDenseVector(output); } else { throw new IllegalArgumentException("input.size() != dim2"); } } @Override public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) { if (other instanceof Matrix) { Matrix otherMat = (Matrix) other; if ((dim1 == otherMat.getDimension1Size()) && (dim2 == otherMat.getDimension2Size())) { double shrinkage = scaleShrinking ? 
1.0 - (baseRate / iteration) : 1.0 - baseRate; scaleInPlace(shrinkage); for (MatrixTuple tuple : otherMat) { double update = f.applyAsDouble(tuple.value); double oldValue = values[tuple.i][tuple.j] * multiplier; double newValue = oldValue + update; squaredTwoNorm -= oldValue * oldValue; squaredTwoNorm += newValue * newValue; values[tuple.i][tuple.j] = newValue / multiplier; } if (reproject) { double projectionNormaliser = (1.0 / lambdaSqrt) / twoNorm(); if (projectionNormaliser < 1.0) { scaleInPlace(projectionNormaliser); } } iteration++; } else { throw new IllegalStateException("Matrices are not the same size, this(" + dim1 + "," + dim2 + "), other(" + otherMat.getDimension1Size() + "," + otherMat.getDimension2Size() + ")"); } } else { throw new IllegalStateException("Adding a non-Matrix to a Matrix"); } } @Override public double get(int i, int j) { return values[i][j] * multiplier; } @Override public void scaleInPlace(double value) { multiplier *= value; if (Math.abs(multiplier) < tolerance) { reifyMultiplier(); } } private void reifyMultiplier() { for (int i = 0; i < dim1; i++) { for (int j = 0; j < dim2; j++) { values[i][j] *= multiplier; } } multiplier = 1.0; } @Override public double twoNorm() { return Math.sqrt(squaredTwoNorm); } @Override public MatrixIterator iterator() { return new ShrinkingMatrixIterator(this); } private class ShrinkingMatrixIterator implements MatrixIterator { private final ShrinkingMatrix matrix; private final MatrixTuple tuple; private int i; private int j; public ShrinkingMatrixIterator(ShrinkingMatrix matrix) { this.matrix = matrix; this.tuple = new MatrixTuple(); this.i = 0; this.j = 0; } @Override public MatrixTuple getReference() { return tuple; } @Override public boolean hasNext() { return (i < matrix.dim1) && (j < matrix.dim2); } @Override public MatrixTuple next() { tuple.i = i; tuple.j = j; tuple.value = matrix.get(i, j); if (j < dim2 - 1) { j++; } else { //Reached end of current vector, get next one i++; j = 0; } return 
tuple; } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy