org.apache.mahout.math.AbstractVector Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mahout-math Show documentation
Show all versions of mahout-math Show documentation
High performance scientific and technical computing data structures and methods, mostly based on CERN's Colt Java API.
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math;
import java.util.Iterator;
import com.google.common.base.Preconditions;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
/**
 * Implementations of generic capabilities like sum of elements and dot products.
 *
 * <p>This base class stores the cardinality and a cached squared length, and builds the generic
 * {@link Vector} operations on top of primitives it does not define itself (for example
 * {@code getQuick}/{@code setQuick} and the two abstract iterators declared below).
 */
public abstract class AbstractVector implements Vector, LengthCachingVector {

  /** Cardinality (number of logical elements) of this vector; set by the constructor. */
  private int size;

  /** Cached result of {@link #getLengthSquared()}; any negative value means "not computed yet". */
  protected double lengthSquared = -1.0;

  /**
   * Creates a vector with the given cardinality.
   *
   * @param size the number of logical elements
   */
  protected AbstractVector(int size) {
    this.size = size;
  }

  /**
   * Exposes {@link #iterator()} as an {@link Iterable} so all elements can be used in a for-each loop.
   *
   * @return an {@link Iterable} whose iterators come from {@link #iterator()}
   */
  @Override
  public Iterable<Element> all() {
    return new Iterable<Element>() {
      @Override
      public Iterator<Element> iterator() {
        return AbstractVector.this.iterator();
      }
    };
  }

  /**
   * Exposes {@link #iterateNonZero()} as an {@link Iterable} so non-zero elements can be used in a
   * for-each loop.
   *
   * @return an {@link Iterable} whose iterators come from {@link #iterateNonZero()}
   */
  @Override
  public Iterable<Element> nonZeroes() {
    return new Iterable<Element>() {
      @Override
      public Iterator<Element> iterator() {
        return iterateNonZero();
      }
    };
  }

  /**
   * Iterates over all elements. NOTE: Implementations may choose to reuse the Element returned for performance
   * reasons, so if you need a copy of it, you should call {@link #getElement(int)} for the given index
   *
   * @return An {@link Iterator} over all elements
   */
  protected abstract Iterator<Element> iterator();

  /**
   * Iterates over all non-zero elements. NOTE: Implementations may choose to reuse the Element returned for
   * performance reasons, so if you need a copy of it, you should call {@link #getElement(int)} for the given index
   *
   * @return An {@link Iterator} over all non-zero elements
   */
  protected abstract Iterator<Element> iterateNonZero();

  /**
   * Aggregates a vector by applying a mapping function fm(x) to every component and aggregating
   * the results with an aggregating function fa(x, y). An empty vector aggregates to 0.
   *
   * @param aggregator used to combine the current value of the aggregation with the result of map.apply(nextValue)
   * @param map a function to apply to each element of the vector in turn before passing to the aggregator
   * @return the result of the aggregation
   */
  @Override
  public double aggregate(DoubleDoubleFunction aggregator, DoubleFunction map) {
    if (size == 0) {
      return 0;
    }

    // If the aggregator is associative and commutative and it's likeLeftMult (fa(0, y) = 0), and there is
    // at least one zero in the vector (size > getNumNondefaultElements) and applying fm(0) = 0, the result
    // gets cascaded through the aggregation and the final result will be 0.
    if (aggregator.isAssociativeAndCommutative() && aggregator.isLikeLeftMult()
        && size > getNumNondefaultElements() && !map.isDensifying()) {
      return 0;
    }

    double result;
    if (isSequentialAccess() || aggregator.isAssociativeAndCommutative()) {
      Iterator<Element> iterator;
      // If fm(0) = 0 and fa(x, 0) = x, we can skip all zero values.
      if (!map.isDensifying() && aggregator.isLikeRightPlus()) {
        iterator = iterateNonZero();
        if (!iterator.hasNext()) {
          return 0;
        }
      } else {
        iterator = iterator();
      }
      // The first mapped value seeds the accumulator; the rest are folded in with the aggregator.
      Element element = iterator.next();
      result = map.apply(element.get());
      while (iterator.hasNext()) {
        element = iterator.next();
        result = aggregator.apply(result, map.apply(element.get()));
      }
    } else {
      // Fall back to walking the indices in order via getQuick.
      result = map.apply(getQuick(0));
      for (int i = 1; i < size; i++) {
        result = aggregator.apply(result, map.apply(getQuick(i)));
      }
    }

    return result;
  }

  /**
   * Aggregates this vector together with another one by delegating to
   * {@code VectorBinaryAggregate.aggregateBest}. An empty vector aggregates to 0.
   *
   * @param other the second vector; must have the same size as this one
   * @param aggregator passed through to {@code VectorBinaryAggregate.aggregateBest}
   * @param combiner passed through to {@code VectorBinaryAggregate.aggregateBest}
   * @return the result of the aggregation
   * @throws IllegalArgumentException if the sizes differ
   */
  @Override
  public double aggregate(Vector other, DoubleDoubleFunction aggregator, DoubleDoubleFunction combiner) {
    Preconditions.checkArgument(size == other.size(), "Vector sizes differ");
    if (size == 0) {
      return 0;
    }
    return VectorBinaryAggregate.aggregateBest(this, other, aggregator, combiner);
  }

  /**
   * Subclasses must override to return an appropriately sparse or dense result
   *
   * @param rows the row cardinality
   * @param columns the column cardinality
   * @return a Matrix
   */
  protected abstract Matrix matrixLike(int rows, int columns);

  /**
   * Returns a {@link VectorView} over {@code length} elements of this vector starting at {@code offset}.
   *
   * @param offset index of the first element of the view
   * @param length number of elements in the view
   * @return the view
   * @throws IndexException if {@code offset} is negative or the requested range extends past the end
   */
  @Override
  public Vector viewPart(int offset, int length) {
    if (offset < 0) {
      throw new IndexException(offset, size);
    }
    // Compare against the remaining room rather than computing offset + length first: the sum can
    // overflow int and wrap to a negative value, which would wrongly pass the bounds check.
    if (length > size - offset) {
      throw new IndexException(offset + length, size);
    }
    return new VectorView(this, offset, length);
  }

  /**
   * Clones this vector via {@code super.clone()}, carrying over the size and the cached squared length.
   *
   * @return the clone
   * @throws IllegalStateException if {@code super.clone()} reports that cloning is not supported
   */
  @SuppressWarnings("CloneDoesntDeclareCloneNotSupportedException")
  @Override
  public Vector clone() {
    try {
      AbstractVector r = (AbstractVector) super.clone();
      r.size = size;
      r.lengthSquared = lengthSquared;
      return r;
    } catch (CloneNotSupportedException e) {
      // Keep the original exception as the cause so the failure remains diagnosable.
      throw new IllegalStateException("Can't happen", e);
    }
  }

  /**
   * Returns a new vector in which every non-zero element is divided by {@code x}.
   *
   * @param x the divisor; when it is exactly 1.0 a plain clone is returned
   * @return the scaled copy
   */
  @Override
  public Vector divide(double x) {
    if (x == 1.0) {
      return clone();
    }
    Vector result = createOptimizedCopy();
    for (Element element : result.nonZeroes()) {
      element.set(element.get() / x);
    }
    return result;
  }

  /**
   * Computes the dot product with another vector.
   *
   * @param x the other vector
   * @return the dot product; when {@code x} is this same instance, the (cached) squared length
   * @throws CardinalityException if the sizes differ
   */
  @Override
  public double dot(Vector x) {
    if (size != x.size()) {
      throw new CardinalityException(size, x.size());
    }
    if (this == x) {
      return getLengthSquared();
    }
    return aggregate(x, Functions.PLUS, Functions.MULT);
  }

  /**
   * Computes the sum of the squared elements without consulting the cache.
   *
   * @return the sum of squares of all elements
   */
  protected double dotSelf() {
    return aggregate(Functions.PLUS, Functions.pow(2));
  }

  /**
   * Bounds-checked element read.
   *
   * @param index the element index
   * @return the value at {@code index}
   * @throws IndexException if {@code index} is outside {@code [0, size)}
   */
  @Override
  public double get(int index) {
    if (index < 0 || index >= size) {
      throw new IndexException(index, size);
    }
    return getQuick(index);
  }

  /**
   * Returns a fresh {@link LocalElement} bound to the given index. No bounds check is performed here.
   *
   * @param index the element index
   * @return a new element handle for {@code index}
   */
  @Override
  public Element getElement(int index) {
    return new LocalElement(index);
  }

  /**
   * Returns a copy of this vector divided by the square root of its squared length.
   *
   * @return the normalized copy
   */
  @Override
  public Vector normalize() {
    return divide(Math.sqrt(getLengthSquared()));
  }

  /**
   * Returns a copy of this vector divided by {@link #norm(double) norm(power)}.
   *
   * @param power the power of the norm to divide by
   * @return the normalized copy
   */
  @Override
  public Vector normalize(double power) {
    return divide(norm(power));
  }

  /**
   * Log-normalizes with power 2 and the square root of the squared length as norm.
   *
   * @return the log-normalized copy
   */
  @Override
  public Vector logNormalize() {
    return logNormalize(2.0, Math.sqrt(getLengthSquared()));
  }

  /**
   * Log-normalizes using {@link #norm(double) norm(power)} as the norm.
   *
   * @param power the power to use; must be greater than 1 and finite
   * @return the log-normalized copy
   */
  @Override
  public Vector logNormalize(double power) {
    return logNormalize(power, norm(power));
  }

  /**
   * Returns a copy in which every non-zero element {@code v} is replaced by
   * {@code log1p(v) / (normLength * log(power))}.
   *
   * @param power the power to use; must be greater than 1 and finite
   * @param normLength the norm value used in the denominator
   * @return the log-normalized copy
   * @throws IllegalArgumentException if {@code power} is infinite or not greater than 1
   */
  public Vector logNormalize(double power, double normLength) {
    if (Double.isInfinite(power) || power <= 1.0) {
      throw new IllegalArgumentException("Power must be > 1 and < infinity");
    }
    double denominator = normLength * Math.log(power);
    Vector result = createOptimizedCopy();
    for (Element element : result.nonZeroes()) {
      element.set(Math.log1p(element.get()) / denominator);
    }
    return result;
  }

  /**
   * Computes the norm of this vector for the given power.
   *
   * <ul>
   *   <li>infinite power: the maximum absolute value</li>
   *   <li>2: the square root of the (cached) squared length</li>
   *   <li>1: the sum of absolute values of the non-zero elements</li>
   *   <li>0: the number of non-zero elements</li>
   *   <li>otherwise: {@code (sum of x^power)^(1/power)}</li>
   * </ul>
   *
   * @param power the power of the norm; must not be negative
   * @return the norm
   * @throws IllegalArgumentException if {@code power} is negative
   */
  @Override
  public double norm(double power) {
    if (power < 0.0) {
      throw new IllegalArgumentException("Power must be >= 0");
    }
    // We can special case certain powers.
    if (Double.isInfinite(power)) {
      return aggregate(Functions.MAX, Functions.ABS);
    } else if (power == 2.0) {
      return Math.sqrt(getLengthSquared());
    } else if (power == 1.0) {
      // TODO: aggregate(Functions.PLUS, Functions.ABS) should ideally be used here, but it's slower.
      double result = 0.0;
      Iterator<Element> iterator = this.iterateNonZero();
      while (iterator.hasNext()) {
        result += Math.abs(iterator.next().get());
      }
      return result;
    } else if (power == 0.0) {
      return getNumNonZeroElements();
    } else {
      return Math.pow(aggregate(Functions.PLUS, Functions.pow(power)), 1.0 / power);
    }
  }

  /**
   * Returns the squared length, computing it with {@link #dotSelf()} and caching it when the cache is empty.
   *
   * @return the sum of squares of all elements
   */
  @Override
  public double getLengthSquared() {
    if (lengthSquared >= 0.0) {
      return lengthSquared;
    }
    lengthSquared = dotSelf();
    return lengthSquared;
  }

  /** Marks the cached squared length as stale so the next {@link #getLengthSquared()} recomputes it. */
  @Override
  public void invalidateCachedLength() {
    lengthSquared = -1;
  }

  /**
   * Computes the squared distance to another vector.
   *
   * <p>The value is first estimated as {@code |this|^2 + |that|^2 - 2 * dot}. When the estimate is small
   * relative to the two squared lengths it is recomputed element-wise instead.
   *
   * @param that the other vector
   * @return the squared distance
   * @throws CardinalityException if the sizes differ
   */
  @Override
  public double getDistanceSquared(Vector that) {
    if (size != that.size()) {
      throw new CardinalityException(size, that.size());
    }
    double thisLength = getLengthSquared();
    double thatLength = that.getLengthSquared();
    double dot = dot(that);
    double distanceEstimate = thisLength + thatLength - 2 * dot;
    if (distanceEstimate > 1.0e-3 * (thisLength + thatLength)) {
      // The vectors are far enough from each other that the formula is accurate.
      return Math.max(distanceEstimate, 0);
    } else {
      return aggregate(that, Functions.PLUS, Functions.MINUS_SQUARED);
    }
  }

  /**
   * Returns the largest element value.
   *
   * @return the maximum, or {@link Double#NEGATIVE_INFINITY} for an empty vector
   */
  @Override
  public double maxValue() {
    if (size == 0) {
      return Double.NEGATIVE_INFINITY;
    }
    return aggregate(Functions.MAX, Functions.IDENTITY);
  }

  /**
   * Returns the index of a largest element.
   *
   * @return the index, or -1 when no candidate was found (for example for an empty vector)
   */
  @Override
  public int maxValueIndex() {
    int result = -1;
    double max = Double.NEGATIVE_INFINITY;
    int nonZeroElements = 0;
    Iterator<Element> iter = this.iterateNonZero();
    while (iter.hasNext()) {
      nonZeroElements++;
      Element element = iter.next();
      double tmp = element.get();
      if (tmp > max) {
        max = tmp;
        result = element.index();
      }
    }
    // if the maxElement is negative and the vector is sparse then any
    // unfilled element(0.0) could be the maxValue hence we need to
    // find one of those elements
    if (nonZeroElements < size && max < 0.0) {
      for (Element element : all()) {
        if (element.get() == 0.0) {
          return element.index();
        }
      }
    }
    return result;
  }

  /**
   * Returns the smallest element value.
   *
   * @return the minimum, or {@link Double#POSITIVE_INFINITY} for an empty vector
   */
  @Override
  public double minValue() {
    if (size == 0) {
      return Double.POSITIVE_INFINITY;
    }
    return aggregate(Functions.MIN, Functions.IDENTITY);
  }

  /**
   * Returns the index of a smallest element.
   *
   * @return the index, or -1 when no candidate was found (for example for an empty vector)
   */
  @Override
  public int minValueIndex() {
    int result = -1;
    double min = Double.POSITIVE_INFINITY;
    int nonZeroElements = 0;
    Iterator<Element> iter = this.iterateNonZero();
    while (iter.hasNext()) {
      nonZeroElements++;
      Element element = iter.next();
      double tmp = element.get();
      if (tmp < min) {
        min = tmp;
        result = element.index();
      }
    }
    // if the minElement is positive and the vector is sparse then any
    // unfilled element(0.0) could be the minValue hence we need to
    // find one of those elements
    if (nonZeroElements < size && min > 0.0) {
      for (Element element : all()) {
        if (element.get() == 0.0) {
          return element.index();
        }
      }
    }
    return result;
  }

  /**
   * Returns a copy of this vector with {@code x} added to every element.
   *
   * @param x the value to add; when it is 0.0 the unmodified copy is returned
   * @return the new vector
   */
  @Override
  public Vector plus(double x) {
    Vector result = createOptimizedCopy();
    if (x == 0.0) {
      return result;
    }
    return result.assign(Functions.plus(x));
  }

  /**
   * Returns a copy of this vector with {@code that} added element-wise.
   *
   * @param that the vector to add
   * @return the new vector
   * @throws CardinalityException if the sizes differ
   */
  @Override
  public Vector plus(Vector that) {
    if (size != that.size()) {
      throw new CardinalityException(size, that.size());
    }
    return createOptimizedCopy().assign(that, Functions.PLUS);
  }

  /**
   * Returns a copy of this vector with {@code that} subtracted element-wise.
   *
   * @param that the vector to subtract
   * @return the new vector
   * @throws CardinalityException if the sizes differ
   */
  @Override
  public Vector minus(Vector that) {
    if (size != that.size()) {
      throw new CardinalityException(size, that.size());
    }
    return createOptimizedCopy().assign(that, Functions.MINUS);
  }

  /**
   * Bounds-checked element write.
   *
   * @param index the element index
   * @param value the value to store
   * @throws IndexException if {@code index} is outside {@code [0, size)}
   */
  @Override
  public void set(int index, double value) {
    if (index < 0 || index >= size) {
      throw new IndexException(index, size);
    }
    setQuick(index, value);
  }

  /**
   * Adds {@code increment} to the element at {@code index} without a bounds check in this method.
   *
   * @param index the element index
   * @param increment the amount to add
   */
  @Override
  public void incrementQuick(int index, double increment) {
    setQuick(index, getQuick(index) + increment);
  }

  /**
   * Returns a copy of this vector with every element multiplied by {@code x}.
   *
   * @param x the factor; when it is 0.0 the result of {@code like()} is returned instead of a scaled copy
   * @return the new vector
   */
  @Override
  public Vector times(double x) {
    if (x == 0.0) {
      return like();
    }
    return createOptimizedCopy().assign(Functions.mult(x));
  }

  /**
   * Copy the current vector in the most optimum fashion. Used by immutable methods like plus(), minus().
   * Use this instead of vector.like().assign(vector). Sub-class can choose to override this method.
   *
   * @return a copy of the current vector.
   */
  protected Vector createOptimizedCopy() {
    return createOptimizedCopy(this);
  }

  /**
   * Copies a vector: dense vectors through {@code like().assign(vector, Functions.SECOND_LEFT_ZERO)},
   * all others through {@code clone()}.
   *
   * @param vector the vector to copy
   * @return the copy
   */
  private static Vector createOptimizedCopy(Vector vector) {
    Vector result;
    if (vector.isDense()) {
      result = vector.like().assign(vector, Functions.SECOND_LEFT_ZERO);
    } else {
      result = vector.clone();
    }
    return result;
  }

  /**
   * Returns the element-wise product of this vector and {@code that}.
   *
   * @param that the other vector
   * @return the new vector
   * @throws CardinalityException if the sizes differ
   */
  @Override
  public Vector times(Vector that) {
    if (size != that.size()) {
      throw new CardinalityException(size, that.size());
    }
    // Copy whichever operand reports fewer non-default elements and multiply the other one into it.
    if (this.getNumNondefaultElements() <= that.getNumNondefaultElements()) {
      return createOptimizedCopy(this).assign(that, Functions.MULT);
    } else {
      return createOptimizedCopy(that).assign(this, Functions.MULT);
    }
  }

  /**
   * Returns the sum of all elements.
   *
   * @return the sum
   */
  @Override
  public double zSum() {
    return aggregate(Functions.PLUS, Functions.IDENTITY);
  }

  /**
   * Counts the elements produced by {@link #iterateNonZero()} whose value is actually different from 0.0.
   *
   * @return the number of non-zero elements
   */
  @Override
  public int getNumNonZeroElements() {
    int count = 0;
    Iterator<Element> it = iterateNonZero();
    while (it.hasNext()) {
      if (it.next().get() != 0.0) {
        count++;
      }
    }
    return count;
  }

  /**
   * Sets every element to {@code value} and invalidates the cached length.
   *
   * @param value the value to assign
   * @return this vector
   */
  @Override
  public Vector assign(double value) {
    if (value == 0.0) {
      // Make all the non-zero values 0.
      Iterator<Element> it = iterateNonZero();
      while (it.hasNext()) {
        it.next().set(value);
      }
    } else {
      if (isSequentialAccess() && !isAddConstantTime()) {
        // Update all the non-zero values and queue the updates for the zero values.
        // The vector will become dense.
        Iterator<Element> it = iterator();
        OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping();
        while (it.hasNext()) {
          Element element = it.next();
          if (element.get() == 0.0) {
            updates.set(element.index(), value);
          } else {
            element.set(value);
          }
        }
        mergeUpdates(updates);
      } else {
        for (int i = 0; i < size; ++i) {
          setQuick(i, value);
        }
      }
    }
    invalidateCachedLength();
    return this;
  }

  /**
   * Copies the given array into this vector and invalidates the cached length.
   *
   * @param values the new element values; its length must equal the vector size
   * @return this vector
   * @throws CardinalityException if {@code values.length} differs from the vector size
   */
  @Override
  public Vector assign(double[] values) {
    if (size != values.length) {
      throw new CardinalityException(size, values.length);
    }
    if (isSequentialAccess() && !isAddConstantTime()) {
      // Existing non-zero slots are updated in place; writes to zero slots are queued and merged at once.
      OrderedIntDoubleMapping updates = new OrderedIntDoubleMapping();
      Iterator<Element> it = iterator();
      while (it.hasNext()) {
        Element element = it.next();
        int index = element.index();
        if (element.get() == 0.0) {
          updates.set(index, values[index]);
        } else {
          element.set(values[index]);
        }
      }
      mergeUpdates(updates);
    } else {
      for (int i = 0; i < size; ++i) {
        setQuick(i, values[i]);
      }
    }
    invalidateCachedLength();
    return this;
  }

  /**
   * Assigns the other vector to this one via {@code assign(other, Functions.SECOND)}.
   *
   * @param other the source vector
   * @return this vector
   */
  @Override
  public Vector assign(Vector other) {
    return assign(other, Functions.SECOND);
  }

  /**
   * Replaces every element {@code v} with {@code f(v, y)} and invalidates the cached length.
   *
   * @param f the function to apply
   * @param y the fixed second argument
   * @return this vector
   */
  @Override
  public Vector assign(DoubleDoubleFunction f, double y) {
    // When f maps (0, y) to 0, zero elements stay zero, so only the non-zero ones need visiting.
    Iterator<Element> iterator = f.apply(0, y) == 0 ? iterateNonZero() : iterator();
    while (iterator.hasNext()) {
      Element element = iterator.next();
      element.set(f.apply(element.get(), y));
    }
    invalidateCachedLength();
    return this;
  }

  /**
   * Replaces every element {@code v} with {@code f(v)} and invalidates the cached length.
   *
   * @param f the function to apply; when it is not densifying only the non-zero elements are visited
   * @return this vector
   */
  @Override
  public Vector assign(DoubleFunction f) {
    Iterator<Element> iterator = !f.isDensifying() ? iterateNonZero() : iterator();
    while (iterator.hasNext()) {
      Element element = iterator.next();
      element.set(f.apply(element.get()));
    }
    invalidateCachedLength();
    return this;
  }

  /**
   * Combines this vector with another one in place by delegating to {@code VectorBinaryAssign.assignBest},
   * then invalidates the cached length.
   *
   * @param other the second operand
   * @param function the combining function
   * @return this vector
   * @throws CardinalityException if the sizes differ
   */
  @Override
  public Vector assign(Vector other, DoubleDoubleFunction function) {
    if (size != other.size()) {
      throw new CardinalityException(size, other.size());
    }
    VectorBinaryAssign.assignBest(this, other, function);
    invalidateCachedLength();
    return this;
  }

  /**
   * Builds the cross (outer) product of this vector and {@code other}: for each non-zero element of this
   * vector, the corresponding row of the result is {@code other} scaled by that element's value.
   *
   * @param other the other vector
   * @return a matrix obtained from {@link #matrixLike(int, int)} with {@code size} rows and
   *         {@code other.size()} columns
   */
  @Override
  public Matrix cross(Vector other) {
    Matrix result = matrixLike(size, other.size());
    Iterator<Element> it = iterateNonZero();
    while (it.hasNext()) {
      Vector.Element e = it.next();
      int row = e.index();
      result.assignRow(row, other.times(getQuick(row)));
    }
    return result;
  }

  /**
   * Returns the cardinality of this vector.
   *
   * @return the number of logical elements
   */
  @Override
  public final int size() {
    return size;
  }

  /**
   * Returns the same text as {@link #toString()}.
   *
   * @return the string form of this vector
   */
  @Override
  public String asFormatString() {
    return toString();
  }

  /**
   * Hash code derived from the size plus, for each element from {@link #iterateNonZero()},
   * its index multiplied by {@code RandomUtils.hashDouble} of its value.
   */
  @Override
  public int hashCode() {
    int result = size;
    Iterator<Element> iter = iterateNonZero();
    while (iter.hasNext()) {
      Element ele = iter.next();
      result += ele.index() * RandomUtils.hashDouble(ele.get());
    }
    return result;
  }

  /**
   * Determines whether this {@link Vector} represents the same logical vector as another
   * object. Two {@link Vector}s are equal (regardless of implementation) if the value at
   * each index is the same, and the cardinalities are the same.
   */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof Vector)) {
      return false;
    }
    Vector that = (Vector) o;
    // Equal sizes and a zero sum of absolute element differences.
    return size == that.size() && aggregate(that, Functions.PLUS, Functions.MINUS_ABS) == 0.0;
  }

  /**
   * Formats this vector with element indexes as labels; see {@link #toString(String[])}.
   */
  @Override
  public String toString() {
    return toString(null);
  }

  /**
   * Formats the non-zero elements as <code>{label:value,label:value}</code>.
   *
   * @param dictionary optional labels by index; when {@code null} or too short for an index, the index
   *                   itself is used as the label
   * @return the formatted string, <code>{}</code> when there is no non-zero element
   */
  public String toString(String[] dictionary) {
    StringBuilder result = new StringBuilder();
    result.append('{');
    for (int index = 0; index < size; index++) {
      double value = getQuick(index);
      if (value != 0.0) {
        result.append(dictionary != null && dictionary.length > index ? dictionary[index] : index);
        result.append(':');
        result.append(value);
        result.append(',');
      }
    }
    if (result.length() > 1) {
      // Turn the trailing comma into the closing brace.
      result.setCharAt(result.length() - 1, '}');
    } else {
      result.append('}');
    }
    return result.toString();
  }

  /**
   * toString() implementation for sparse vectors via {@link #iterateNonZero()}
   *
   * @return String representation of the vector
   */
  public String sparseVectorToString() {
    Iterator<Element> it = iterateNonZero();
    if (!it.hasNext()) {
      return "{}";
    }
    StringBuilder result = new StringBuilder();
    result.append('{');
    while (it.hasNext()) {
      Vector.Element e = it.next();
      result.append(e.index());
      result.append(':');
      result.append(e.get());
      result.append(',');
    }
    // Turn the trailing comma into the closing brace.
    result.setCharAt(result.length() - 1, '}');
    return result.toString();
  }

  /**
   * An {@link Element} that reads and writes the enclosing vector through {@code getQuick}/{@code setQuick}
   * at a stored index.
   */
  protected final class LocalElement implements Element {

    // NOTE(review): left package-private and non-final exactly as in the original; other code in this
    // package may assign to it directly - confirm before tightening the access.
    int index;

    LocalElement(int index) {
      this.index = index;
    }

    @Override
    public double get() {
      return getQuick(index);
    }

    @Override
    public int index() {
      return index;
    }

    @Override
    public void set(double value) {
      setQuick(index, value);
    }
  }
}