
org.biojava.nbio.structure.math.SparseVector Maven / Gradle / Ivy
/**
* BioJava development code
*
* This code may be freely distributed and modified under the
* terms of the GNU Lesser General Public Licence. This should
* be distributed with the code. If you do not have a copy,
* see:
*
* http://www.gnu.org/copyleft/lesser.html
*
* Copyright for this code is held jointly by the individual
* authors. These should be listed in @author doc comments.
*
* For more information on the BioJava project and its aims,
* or to join the biojava-l mailing list, visit the home page
* at:
*
* http://www.biojava.org/
*
* Created on 5 Mar 2013
* Created by Andreas Prlic
*
* @since 3.0.6
*/
package org.biojava.nbio.structure.math;
import java.io.Serializable;
/**
*
* A sparse vector, implemented using a symbol table.
*
* Derived from http://introcs.cs.princeton.edu/java/44st/SparseVector.java.html
*
* For additional documentation, see Section 4.4 of
* Introduction to Programming in Java: An Interdisciplinary Approach by Robert Sedgewick and Kevin Wayne.
*/
public class SparseVector implements Serializable{
/**
*
*/
private static final long serialVersionUID = 1174668523213431927L;
private final int N; // length
private SymbolTable symbolTable; // the vector, represented by index-value pairs
/** Constructor. initialize the all 0s vector of length N
*
* @param N
*/
public SparseVector(int N) {
this.N = N;
this.symbolTable = new SymbolTable();
}
/** Setter method (should it be renamed to set?)
*
* @param i set symbolTable[i]
* @param value
*/
public void put(int i, double value) {
if (i < 0 || i >= N) throw new IllegalArgumentException("Illegal index " + i + " should be > 0 and < " + N);
if (value == 0.0) symbolTable.delete(i);
else symbolTable.put(i, value);
}
/** get a value
*
* @param i
* @return return symbolTable[i]
*/
public double get(int i) {
if (i < 0 || i >= N) throw new IllegalArgumentException("Illegal index " + i + " should be > 0 and < " + N);
if (symbolTable.contains(i)) return symbolTable.get(i);
else return 0.0;
}
// return the number of nonzero entries
public int nnz() {
return symbolTable.size();
}
// return the size of the vector
public int size() {
return N;
}
/** Calculates the dot product of this vector a with b
*
* @param b
* @return
*/
public double dot(SparseVector b) {
SparseVector a = this;
if (a.N != b.N) throw new IllegalArgumentException("Vector lengths disagree. " + a.N + " != " + b.N);
double sum = 0.0;
// iterate over the vector with the fewest nonzeros
if (a.symbolTable.size() <= b.symbolTable.size()) {
for (int i : a.symbolTable)
if (b.symbolTable.contains(i)) sum += a.get(i) * b.get(i);
}
else {
for (int i : b.symbolTable)
if (a.symbolTable.contains(i)) sum += a.get(i) * b.get(i);
}
return sum;
}
/** Calculates the 2-norm
*
* @return
*/
public double norm() {
SparseVector a = this;
return Math.sqrt(a.dot(a));
}
/** Calculates alpha * a
*
* @param alpha
* @return
*/
public SparseVector scale(double alpha) {
SparseVector a = this;
SparseVector c = new SparseVector(N);
for (int i : a.symbolTable) c.put(i, alpha * a.get(i));
return c;
}
/** Calcualtes return a + b
*
* @param b
* @return
*/
public SparseVector plus(SparseVector b) {
SparseVector a = this;
if (a.N != b.N) throw new IllegalArgumentException("Vector lengths disagree : " + a.N + " != " + b.N);
SparseVector c = new SparseVector(N);
for (int i : a.symbolTable) c.put(i, a.get(i)); // c = a
for (int i : b.symbolTable) c.put(i, b.get(i) + c.get(i)); // c = c + b
return c;
}
@Override
public String toString() {
StringBuilder s = new StringBuilder();
for (int i : symbolTable) {
s.append("(");
s.append(i);
s.append(", ");
s.append(symbolTable.get(i));
s.append(") ");
}
return s.toString();
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy