All Downloads are FREE. Search and download functionalities use the official Maven repository.

org.carrot2.mahout.math.matrix.impl.SparseDoubleMatrix2D Maven / Gradle / Ivy

/* Imported from Mahout. */package org.carrot2.mahout.math.matrix.impl;

import org.carrot2.mahout.math.function.DoubleDoubleFunction;
import org.carrot2.mahout.math.function.DoubleFunction;
import org.carrot2.mahout.math.function.Functions;
import org.carrot2.mahout.math.function.IntDoubleProcedure;
import org.carrot2.mahout.math.function.IntIntDoubleFunction;
import org.carrot2.mahout.math.function.Mult;
import org.carrot2.mahout.math.function.PlusMult;
import org.carrot2.mahout.math.map.AbstractIntDoubleMap;
import org.carrot2.mahout.math.map.OpenIntDoubleHashMap;
import org.carrot2.mahout.math.matrix.DoubleMatrix1D;
import org.carrot2.mahout.math.matrix.DoubleMatrix2D;

/**
 * Sparse 2-D matrix of doubles backed by a hash map from a linear cell index to the cell value.
 *
 * <p>{@link #setQuick(int, int, double)} removes a cell's entry when the cell is set to zero, so
 * in general only non-zero cells occupy storage. {@link #forEachNonZero(IntIntDoubleFunction)}
 * and the element-wise {@code MULT}/{@code DIV} fast paths of
 * {@link #assign(DoubleMatrix2D, DoubleDoubleFunction)} write results back with {@code put}, so
 * they may leave explicit zero entries in the map.
 *
 * <p>Several bulk operations take a fast path only when {@code isNoView} is true. Those paths
 * decode a map key back into a cell as {@code (key / columns, key % columns)}. This presumably
 * relies on a non-view matrix using row-major keys {@code row * columns + column}, with zero
 * offsets, {@code rowStride == columns} and {@code columnStride == 1}. That layout is set up by
 * the superclass; confirm in {@code DoubleMatrix2D.setUp}.
 */
public final class SparseDoubleMatrix2D extends DoubleMatrix2D {
  /** The stored cells, keyed by the linear index computed by {@link #index(int, int)}. */
  final AbstractIntDoubleMap elements;

  /**
   * Constructs a matrix with the shape of {@code values} and fills it via
   * {@code assign(double[][])}.
   *
   * @param values the cell values as {@code values[row][column]}; the column count is taken from
   *     {@code values[0]}, or is 0 when there are no rows
   */
  public SparseDoubleMatrix2D(double[][] values) {
    this(values.length, values.length == 0 ? 0 : values[0].length);
    assign(values);
  }

  /**
   * Constructs a matrix with the given shape and no stored cells, using default hash map
   * parameters.
   *
   * <p>The defaults are an initial capacity of {@code rows * (columns / 1000)}, a minimum load
   * factor of 0.2 and a maximum load factor of 0.5.
   *
   * @param rows number of rows
   * @param columns number of columns
   */
  public SparseDoubleMatrix2D(int rows, int columns) {
    this(rows, columns, rows * (columns / 1000), 0.2, 0.5);
  }

  /**
   * Constructs a matrix with the given shape and no stored cells.
   *
   * @param rows number of rows
   * @param columns number of columns
   * @param initialCapacity initial capacity passed to the backing {@link OpenIntDoubleHashMap}
   * @param minLoadFactor minimum load factor passed to the backing map
   * @param maxLoadFactor maximum load factor passed to the backing map
   */
  public SparseDoubleMatrix2D(int rows, int columns, int initialCapacity, double minLoadFactor, double maxLoadFactor) {
    setUp(rows, columns);
    this.elements = new OpenIntDoubleHashMap(initialCapacity, minLoadFactor, maxLoadFactor);
  }

  /**
   * Sets every cell to {@code value}.
   *
   * <p>Overridden for performance only: for a non-view matrix, assigning zero just clears the map.
   *
   * @param value the value to assign
   * @return this matrix
   */
  @Override
  public DoubleMatrix2D assign(double value) {
    if (this.isNoView && value == 0) {
      this.elements.clear();
    } else {
      super.assign(value);
    }
    return this;
  }

  /**
   * Applies {@code function} to every cell.
   *
   * <p>For a non-view matrix and a {@link Mult} function ({@code x[i] = mult*x[i]}), only the
   * stored entries are transformed, because scaling leaves zero cells at zero.
   *
   * @param function the function to apply to each cell value
   */
  @Override
  public void assign(DoubleFunction function) {
    if (this.isNoView && function instanceof Mult) {
      this.elements.assign(function);
    } else {
      super.assign(function);
    }
  }

  /**
   * Copies all cell values of {@code source} into this matrix.
   *
   * <p>Overridden for performance only. When both matrices are non-view sparse matrices, the
   * backing map is copied wholesale. Otherwise the generic superclass copy is used.
   *
   * @param source the matrix to copy from; must have the same shape (checked by
   *     {@code checkShape})
   * @return this matrix
   */
  @Override
  public DoubleMatrix2D assign(DoubleMatrix2D source) {
    if (!(source instanceof SparseDoubleMatrix2D)) {
      return super.assign(source);
    }
    SparseDoubleMatrix2D other = (SparseDoubleMatrix2D) source;
    if (other == this) {
      return this; // nothing to do
    }
    checkShape(other);

    if (this.isNoView && other.isNoView) { // quickest
      this.elements.assign(other.elements);
      return this;
    }
    return super.assign(source);
  }

  /**
   * Sets each cell to {@code function(this[i][j], y[i][j])}.
   *
   * <p>Fast paths for a non-view matrix:
   * <ul>
   *   <li>{@link PlusMult} ({@code x = x + alpha*y}): visits only the non-zero cells of {@code y}.
   *   <li>{@code Functions.MULT} ({@code x = x * y}) and {@code Functions.DIV}
   *       ({@code x = x / y}): visit only the stored entries of this matrix, so cells that are
   *       zero here stay zero. For {@code DIV} this means a {@code 0 / 0} cell stays {@code 0}
   *       rather than becoming NaN.
   * </ul>
   * All other functions, and all views, use the generic superclass implementation.
   *
   * @param y the second operand; must have the same shape as this matrix
   * @param function the binary function to apply
   * @return this matrix
   */
  @Override
  public DoubleMatrix2D assign(final DoubleMatrix2D y,
                               DoubleDoubleFunction function) {
    if (!this.isNoView) {
      return super.assign(y, function);
    }

    checkShape(y);

    if (function instanceof PlusMult) { // x[i] = x[i] + alpha*y[i]
      final double alpha = ((PlusMult) function).getMultiplicator();
      if (alpha == 0) {
        return this; // nothing to do
      }
      y.forEachNonZero(
          new IntIntDoubleFunction() {
            @Override
            public double apply(int i, int j, double value) {
              setQuick(i, j, getQuick(i, j) + alpha * value);
              return value; // leave y unchanged
            }
          }
      );
      return this;
    }

    if (function == Functions.MULT) { // x[i] = x[i] * y[i]
      this.elements.forEachPair(
          new IntDoubleProcedure() {
            @Override
            public boolean apply(int key, double value) {
              int i = key / columns;
              int j = key % columns;
              double r = value * y.getQuick(i, j);
              if (r != value) {
                elements.put(key, r);
              }
              return true;
            }
          }
      );
      // The stored entries now hold x*y. Falling through to super.assign would apply the
      // function a second time (x*y*y).
      return this;
    }

    if (function == Functions.DIV) { // x[i] = x[i] / y[i]
      this.elements.forEachPair(
          new IntDoubleProcedure() {
            @Override
            public boolean apply(int key, double value) {
              int i = key / columns;
              int j = key % columns;
              double r = value / y.getQuick(i, j);
              if (r != value) {
                elements.put(key, r);
              }
              return true;
            }
          }
      );
      // Same reasoning as MULT: do not divide a second time in super.assign.
      return this;
    }

    return super.assign(y, function);
  }

  /**
   * Returns the number of non-zero cells.
   *
   * <p>For a non-view matrix this is the number of stored map entries, which may include explicit
   * zeros written by bulk operations (see the class comment).
   *
   * @return the cardinality of this matrix
   */
  @Override
  public int cardinality() {
    return this.isNoView ? this.elements.size() : super.cardinality();
  }

  /**
   * Ensures the backing map can hold at least {@code minCapacity} entries; delegates to the map.
   *
   * @param minCapacity the desired minimum capacity
   */
  @Override
  public void ensureCapacity(int minCapacity) {
    this.elements.ensureCapacity(minCapacity);
  }

  /**
   * Calls {@code function(row, column, value)} for each non-zero cell and stores its return value
   * back into that cell.
   *
   * <p>For a non-view matrix this iterates the stored entries directly. A result is written back
   * only if it differs from the current value, and it is written with {@code put}, so a zero
   * result is kept as an explicit zero entry.
   *
   * @param function the callback; its return value becomes the new cell value
   */
  @Override
  public void forEachNonZero(final IntIntDoubleFunction function) {
    if (this.isNoView) {
      this.elements.forEachPair(
          new IntDoubleProcedure() {
            @Override
            public boolean apply(int key, double value) {
              int i = key / columns;
              int j = key % columns;
              double r = function.apply(i, j, value);
              if (r != value) {
                elements.put(key, r);
              }
              return true;
            }
          }
      );
    } else {
      super.forEachNonZero(function);
    }
  }

  /**
   * Returns the value of cell {@code (row, column)} without bounds checking.
   *
   * @param row the row index
   * @param column the column index
   * @return the stored value for that cell, or the backing map's value for an absent key
   */
  @Override
  public double getQuick(int row, int column) {
    // index(row, column) manually inlined for speed.
    return this.elements.get(rowZero + row * rowStride + columnZero + column * columnStride);
  }

  /**
   * Returns whether {@code other} is backed by the same cell map as this matrix.
   *
   * @param other the matrix to compare against
   * @return {@code true} if both are sparse matrices sharing the same {@code elements} map
   */
  @Override
  protected boolean haveSharedCellsRaw(DoubleMatrix2D other) {
    if (other instanceof SelectedSparseDoubleMatrix2D) {
      SelectedSparseDoubleMatrix2D otherMatrix = (SelectedSparseDoubleMatrix2D) other;
      return this.elements == otherMatrix.elements;
    }
    if (other instanceof SparseDoubleMatrix2D) {
      SparseDoubleMatrix2D otherMatrix = (SparseDoubleMatrix2D) other;
      return this.elements == otherMatrix.elements;
    }
    return false;
  }

  /**
   * Returns the map key of cell {@code (row, column)}, taking view offsets and strides into
   * account.
   *
   * @param row the row index
   * @param column the column index
   * @return {@code rowZero + row * rowStride + columnZero + column * columnStride}
   */
  @Override
  protected int index(int row, int column) {
    // Same as super.index(row, column), inlined for speed.
    return rowZero + row * rowStride + columnZero + column * columnStride;
  }

  /**
   * Creates a new sparse matrix of the given shape with default map parameters.
   *
   * @param rows number of rows
   * @param columns number of columns
   * @return a new {@code SparseDoubleMatrix2D}
   */
  @Override
  public DoubleMatrix2D like(int rows, int columns) {
    return new SparseDoubleMatrix2D(rows, columns);
  }

  /**
   * Creates a new sparse 1-D matrix of the given size.
   *
   * @param size the number of cells
   * @return a new {@code SparseDoubleMatrix1D}
   */
  @Override
  public DoubleMatrix1D like1D(int size) {
    return new SparseDoubleMatrix1D(size);
  }

  /**
   * Creates a sparse 1-D matrix that shares this matrix's cell map.
   *
   * @param size the number of cells
   * @param offset the map key of the first cell
   * @param stride the key distance between consecutive cells
   * @return a {@code SparseDoubleMatrix1D} backed by {@link #elements}
   */
  @Override
  protected DoubleMatrix1D like1D(int size, int offset, int stride) {
    return new SparseDoubleMatrix1D(size, this.elements, offset, stride);
  }

  /**
   * Sets cell {@code (row, column)} to {@code value} without bounds checking.
   *
   * <p>Setting a cell to exactly zero removes its entry, so it no longer occupies storage.
   *
   * @param row the row index
   * @param column the column index
   * @param value the new value
   */
  @Override
  public void setQuick(int row, int column, double value) {
    // index(row, column) manually inlined for speed.
    int index = rowZero + row * rowStride + columnZero + column * columnStride;

    if (value == 0) {
      this.elements.removeKey(index);
    } else {
      this.elements.put(index, value);
    }
  }

  /**
   * Creates a selection view over this matrix's cell map.
   *
   * @param rowOffsets map-key offsets of the selected rows
   * @param columnOffsets map-key offsets of the selected columns
   * @return a {@code SelectedSparseDoubleMatrix2D} sharing {@link #elements}
   */
  @Override
  protected DoubleMatrix2D viewSelectionLike(int[] rowOffsets, int[] columnOffsets) {
    return new SelectedSparseDoubleMatrix2D(this.elements, rowOffsets, columnOffsets, 0);
  }

  /**
   * Computes {@code z = alpha * op(A) * y + beta * z}, where {@code op(A)} is this matrix or its
   * transpose.
   *
   * <p>The sparse fast path iterates only the stored entries and is used when this matrix is not
   * a view and both {@code y} and {@code z} are {@link DenseDoubleMatrix1D}. Otherwise the
   * superclass implementation is used.
   *
   * @param y the vector to multiply; its size must equal the column count of {@code op(A)}
   * @param z the result vector; if {@code null}, a new dense vector of the row count of
   *     {@code op(A)} is allocated and {@code beta} is ignored
   * @param alpha scale factor for {@code op(A) * y}
   * @param beta scale factor for the incoming {@code z}
   * @param transposeA whether to use the transpose of this matrix
   * @return {@code z}
   * @throws IllegalArgumentException on the fast path, if the sizes of {@code y} or {@code z}
   *     are incompatible
   */
  @Override
  public DoubleMatrix1D zMult(DoubleMatrix1D y, DoubleMatrix1D z, double alpha, double beta, final boolean transposeA) {
    int m = rows;
    int n = columns;
    if (transposeA) {
      m = columns;
      n = rows;
    }

    boolean ignore = z == null;
    if (ignore) {
      z = new DenseDoubleMatrix1D(m);
    }

    if (!(this.isNoView && y instanceof DenseDoubleMatrix1D && z instanceof DenseDoubleMatrix1D)) {
      return super.zMult(y, z, alpha, beta, transposeA);
    }

    if (n != y.size() || m > z.size()) {
      throw new IllegalArgumentException("Incompatible args");
    }

    if (alpha == 0) {
      // alpha * op(A) * y vanishes. Without this shortcut, the beta / alpha pre-scaling below
      // would divide by zero and turn every cell of z into NaN or +/-Infinity.
      if (!ignore) {
        z.assign(Functions.mult(beta));
      }
      return z;
    }

    // Pre-scale by beta / alpha so the final scaling by alpha yields beta*z + alpha*op(A)*y.
    if (!ignore) {
      z.assign(Functions.mult(beta / alpha));
    }

    DenseDoubleMatrix1D zz = (DenseDoubleMatrix1D) z;
    final double[] zElements = zz.elements;
    final int zStride = zz.stride;
    final int zi = z.index(0);

    DenseDoubleMatrix1D yy = (DenseDoubleMatrix1D) y;
    final double[] yElements = yy.elements;
    final int yStride = yy.stride;
    final int yi = y.index(0);

    if (yElements == null || zElements == null) {
      throw new IllegalStateException();
    }

    this.elements.forEachPair(
        new IntDoubleProcedure() {
          @Override
          public boolean apply(int key, double value) {
            int i = key / columns;
            int j = key % columns;
            if (transposeA) {
              int tmp = i;
              i = j;
              j = tmp;
            }
            zElements[zi + zStride * i] += value * yElements[yi + yStride * j];
            return true;
          }
        }
    );

    if (alpha != 1.0) {
      z.assign(Functions.mult(alpha));
    }
    return z;
  }

  /**
   * Computes {@code C = alpha * op(A) * op(B) + beta * C}, where {@code op(X)} is {@code X} or its
   * transpose.
   *
   * <p>For a non-view matrix, each stored entry {@code (i, j, value)} adds
   * {@code value * alpha} times a row of {@code op(B)} to a row of {@code C}. Views use the
   * superclass implementation.
   *
   * @param B the right-hand operand
   * @param C the result matrix; if {@code null}, a new dense matrix is allocated and {@code beta}
   *     is ignored
   * @param alpha scale factor for the product
   * @param beta scale factor for the incoming {@code C}
   * @param transposeA whether to use the transpose of this matrix
   * @param transposeB whether to use the transpose of {@code B}
   * @return {@code C}
   * @throws IllegalArgumentException if the dimensions are incompatible, or if {@code C} is this
   *     matrix or {@code B}
   */
  @Override
  public DoubleMatrix2D zMult(DoubleMatrix2D B, DoubleMatrix2D C, final double alpha, double beta,
                              final boolean transposeA, boolean transposeB) {
    if (!this.isNoView) {
      return super.zMult(B, C, alpha, beta, transposeA, transposeB);
    }
    // Remember the caller's B. After B is replaced by its transposed view, a plain B == C test
    // can no longer detect that the caller passed the same matrix as B and C.
    final DoubleMatrix2D originalB = B;
    if (transposeB) {
      B = B.viewDice();
    }
    int m = rows;
    int n = columns;
    if (transposeA) {
      m = columns;
      n = rows;
    }
    int p = B.columns;
    boolean ignore = C == null;
    if (C == null) {
      C = new DenseDoubleMatrix2D(m, p);
    }

    if (B.rows != n) {
      throw new IllegalArgumentException("Matrix2D inner dimensions must agree");
    }
    if (C.rows != m || C.columns != p) {
      throw new IllegalArgumentException("Incompatible result matrix");
    }
    if (this == C || B == C || originalB == C) {
      throw new IllegalArgumentException("Matrices must not be identical");
    }

    if (!ignore) {
      C.assign(Functions.mult(beta));
    }

    // Cache row views so each stored entry costs only one row update.
    final DoubleMatrix1D[] Brows = new DoubleMatrix1D[n];
    for (int i = n; --i >= 0;) {
      Brows[i] = B.viewRow(i);
    }
    final DoubleMatrix1D[] Crows = new DoubleMatrix1D[m];
    for (int i = m; --i >= 0;) {
      Crows[i] = C.viewRow(i);
    }

    // Reused for every entry; its multiplicator is reset before each row update.
    final PlusMult fun = PlusMult.plusMult(0);

    this.elements.forEachPair(
        new IntDoubleProcedure() {
          @Override
          public boolean apply(int key, double value) {
            int i = key / columns;
            int j = key % columns;
            fun.setMultiplicator(value * alpha);
            if (transposeA) {
              Crows[j].assign(Brows[i], fun);
            } else {
              Crows[i].assign(Brows[j], fun);
            }
            return true;
          }
        }
    );

    return C;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy