All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.carrot2.mahout.math.matrix.impl.DenseDoubleMatrix2D Maven / Gradle / Ivy

Go to download

Carrot2 search results clustering framework. Minimal functional subset (core algorithms and infrastructure, no document sources).

There is a newer version: 3.16.3
Show newest version
/* Imported from Mahout. */package org.carrot2.mahout.math.matrix.impl;

import org.carrot2.mahout.math.function.DoubleDoubleFunction;
import org.carrot2.mahout.math.function.DoubleFunction;
import org.carrot2.mahout.math.function.Functions;
import org.carrot2.mahout.math.function.Mult;
import org.carrot2.mahout.math.function.PlusMult;
import org.carrot2.mahout.math.matrix.DoubleMatrix1D;
import org.carrot2.mahout.math.matrix.DoubleMatrix2D;

public final class DenseDoubleMatrix2D extends DoubleMatrix2D {

  
  final double[] elements;

  
  public DenseDoubleMatrix2D(double[][] values) {
    this(values.length, values.length == 0 ? 0 : values[0].length);
    assign(values);
  }

  
  public DenseDoubleMatrix2D(int rows, int columns) {
    setUp(rows, columns);
    this.elements = new double[rows * columns];
  }

  
  public static DoubleMatrix2D identity(int rowsAndColumns) {
    DoubleMatrix2D matrix = new DenseDoubleMatrix2D(rowsAndColumns, rowsAndColumns);
    for (int i = rowsAndColumns; --i >= 0;) {
      matrix.setQuick(i, i, 1);
    }
    return matrix;
  }

  
  @Override
  public void assign(double[][] values) {
    if (this.isNoView) {
      if (values.length != rows) {
        throw new IllegalArgumentException("Must have same number of rows: rows=" + values.length + "rows()=" + rows());
      }
      int i = columns * (rows - 1);
      for (int row = rows; --row >= 0;) {
        double[] currentRow = values[row];
        if (currentRow.length != columns) {
          throw new IllegalArgumentException(
              "Must have same number of columns in every row: columns=" + currentRow.length + "columns()=" + columns());
        }
        System.arraycopy(currentRow, 0, this.elements, i, columns);
        i -= columns;
      }
    } else {
      super.assign(values);
    }
  }

  
  @Override
  public DoubleMatrix2D assign(double value) {
    double[] elems = this.elements;
    int index = index(0, 0);
    int cs = this.columnStride;
    int rs = this.rowStride;
    for (int row = rows; --row >= 0;) {
      for (int i = index, column = columns; --column >= 0;) {
        elems[i] = value;
        i += cs;
      }
      index += rs;
    }
    return this;
  }

  
  @Override
  public void assign(DoubleFunction function) {
    double[] elems = this.elements;
    if (elems == null) {
      throw new IllegalStateException();
    }
    int index = index(0, 0);
    int cs = this.columnStride;
    int rs = this.rowStride;

    // specialization for speed
    if (function instanceof Mult) { // x[i] = mult*x[i]
      double multiplicator = ((Mult) function).getMultiplicator();
      if (multiplicator == 1) {
        return;
      }
      if (multiplicator == 0) {
        assign(0);
        return;
      }
      for (int row = rows; --row >= 0;) { // the general case
        for (int i = index, column = columns; --column >= 0;) {
          elems[i] *= multiplicator;
          i += cs;
        }
        index += rs;
      }
    } else { // the general case x[i] = f(x[i])
      for (int row = rows; --row >= 0;) {
        for (int i = index, column = columns; --column >= 0;) {
          elems[i] = function.apply(elems[i]);
          i += cs;
        }
        index += rs;
      }
    }
  }

  
  @Override
  public DoubleMatrix2D assign(DoubleMatrix2D source) {
    // overriden for performance only
    if (!(source instanceof DenseDoubleMatrix2D)) {
      return super.assign(source);
    }
    DenseDoubleMatrix2D other = (DenseDoubleMatrix2D) source;
    if (other == this) {
      return this;
    } // nothing to do
    checkShape(other);

    if (this.isNoView && other.isNoView) { // quickest
      System.arraycopy(other.elements, 0, this.elements, 0, this.elements.length);
      return this;
    }

    if (haveSharedCells(other)) {
      DoubleMatrix2D c = other.copy();
      if (!(c instanceof DenseDoubleMatrix2D)) { // should not happen
        return super.assign(other);
      }
      other = (DenseDoubleMatrix2D) c;
    }

    double[] elems = this.elements;
    double[] otherElems = other.elements;
    if (elems == null || otherElems == null) {
      throw new IllegalStateException();
    }
    int cs = this.columnStride;
    int ocs = other.columnStride;
    int rs = this.rowStride;
    int ors = other.rowStride;

    int otherIndex = other.index(0, 0);
    int index = index(0, 0);
    for (int row = rows; --row >= 0;) {
      for (int i = index, j = otherIndex, column = columns; --column >= 0;) {
        elems[i] = otherElems[j];
        i += cs;
        j += ocs;
      }
      index += rs;
      otherIndex += ors;
    }
    return this;
  }

  
  @Override
  public DoubleMatrix2D assign(DoubleMatrix2D y, DoubleDoubleFunction function) {
    // overriden for performance only
    if (!(y instanceof DenseDoubleMatrix2D)) {
      return super.assign(y, function);
    }
    DenseDoubleMatrix2D other = (DenseDoubleMatrix2D) y;
    checkShape(y);

    double[] elems = this.elements;
    double[] otherElems = other.elements;
    if (elems == null || otherElems == null) {
      throw new IllegalStateException();
    }
    int cs = this.columnStride;
    int ocs = other.columnStride;
    int rs = this.rowStride;
    int ors = other.rowStride;

    int otherIndex = other.index(0, 0);
    int index = index(0, 0);

    // specialized for speed
    if (function == Functions.MULT) { // x[i] = x[i] * y[i]
      for (int row = rows; --row >= 0;) {
        for (int i = index, j = otherIndex, column = columns; --column >= 0;) {
          elems[i] *= otherElems[j];
          i += cs;
          j += ocs;
        }
        index += rs;
        otherIndex += ors;
      }
    } else if (function == Functions.DIV) { // x[i] = x[i] / y[i]
      for (int row = rows; --row >= 0;) {
        for (int i = index, j = otherIndex, column = columns; --column >= 0;) {
          elems[i] /= otherElems[j];
          i += cs;
          j += ocs;
        }
        index += rs;
        otherIndex += ors;
      }
    } else if (function instanceof PlusMult) {
      double multiplicator = ((PlusMult) function).getMultiplicator();
      if (multiplicator == 0) { // x[i] = x[i] + 0*y[i]
        return this;
      } else if (multiplicator == 1) { // x[i] = x[i] + y[i]
        for (int row = rows; --row >= 0;) {
          for (int i = index, j = otherIndex, column = columns; --column >= 0;) {
            elems[i] += otherElems[j];
            i += cs;
            j += ocs;
          }
          index += rs;
          otherIndex += ors;
        }
      } else if (multiplicator == -1) { // x[i] = x[i] - y[i]
        for (int row = rows; --row >= 0;) {
          for (int i = index, j = otherIndex, column = columns; --column >= 0;) {
            elems[i] -= otherElems[j];
            i += cs;
            j += ocs;
          }
          index += rs;
          otherIndex += ors;
        }
      } else { // the general case
        for (int row = rows; --row >= 0;) { // x[i] = x[i] + mult*y[i]
          for (int i = index, j = otherIndex, column = columns; --column >= 0;) {
            elems[i] += multiplicator * otherElems[j];
            i += cs;
            j += ocs;
          }
          index += rs;
          otherIndex += ors;
        }
      }
    } else { // the general case x[i] = f(x[i],y[i])
      for (int row = rows; --row >= 0;) {
        for (int i = index, j = otherIndex, column = columns; --column >= 0;) {
          elems[i] = function.apply(elems[i], otherElems[j]);
          i += cs;
          j += ocs;
        }
        index += rs;
        otherIndex += ors;
      }
    }
    return this;
  }

  
  @Override
  public double getQuick(int row, int column) {
    //if (debug) if (column<0 || column>=columns || row<0 || row>=rows)
    // throw new IndexOutOfBoundsException("row:"+row+", column:"+column);
    //return elements[index(row,column)];
    //manually inlined:
    return elements[rowZero + row * rowStride + columnZero + column * columnStride];
  }

  
  @Override
  protected boolean haveSharedCellsRaw(DoubleMatrix2D other) {
    if (other instanceof SelectedDenseDoubleMatrix2D) {
      SelectedDenseDoubleMatrix2D otherMatrix = (SelectedDenseDoubleMatrix2D) other;
      return this.elements == otherMatrix.elements;
    }
    if (other instanceof DenseDoubleMatrix2D) {
      DenseDoubleMatrix2D otherMatrix = (DenseDoubleMatrix2D) other;
      return this.elements == otherMatrix.elements;
    }
    return false;
  }

  
  @Override
  protected int index(int row, int column) {
    // return super.index(row,column);
    // manually inlined for speed:
    return rowZero + row * rowStride + columnZero + column * columnStride;
  }

  
  @Override
  public DoubleMatrix2D like(int rows, int columns) {
    return new DenseDoubleMatrix2D(rows, columns);
  }

  
  @Override
  public DoubleMatrix1D like1D(int size) {
    return new DenseDoubleMatrix1D(size);
  }

  
  @Override
  protected DoubleMatrix1D like1D(int size, int zero, int stride) {
    return new DenseDoubleMatrix1D(size, this.elements, zero, stride);
  }

  
  @Override
  public void setQuick(int row, int column, double value) {
    //if (debug) if (column<0 || column>=columns || row<0 || row>=rows)
    // throw new IndexOutOfBoundsException("row:"+row+", column:"+column);
    //elements[index(row,column)] = value;
    //manually inlined:
    elements[rowZero + row * rowStride + columnZero + column * columnStride] = value;
  }

  
  @Override
  protected DoubleMatrix2D viewSelectionLike(int[] rowOffsets, int[] columnOffsets) {
    return new SelectedDenseDoubleMatrix2D(this.elements, rowOffsets, columnOffsets, 0);
  }

  @Override
  public DoubleMatrix1D zMult(DoubleMatrix1D y, DoubleMatrix1D z, double alpha, double beta, boolean transposeA) {
    if (transposeA) {
      return viewDice().zMult(y, z, alpha, beta, false);
    }
    if (z == null) {
      z = new DenseDoubleMatrix1D(this.rows);
    }
    if (!(y instanceof DenseDoubleMatrix1D && z instanceof DenseDoubleMatrix1D)) {
      return super.zMult(y, z, alpha, beta, transposeA);
    }

    if (columns != y.size || rows > z.size) {
      throw new IllegalArgumentException("Incompatible sizes");
    }

    DenseDoubleMatrix1D yy = (DenseDoubleMatrix1D) y;
    DenseDoubleMatrix1D zz = (DenseDoubleMatrix1D) z;
    double[] AElems = this.elements;
    double[] yElems = yy.elements;
    double[] zElems = zz.elements;
    if (AElems == null || yElems == null || zElems == null) {
      throw new IllegalStateException();
    }
    int As = this.columnStride;
    int ys = yy.stride;
    int zs = zz.stride;

    int indexA = index(0, 0);
    int indexY = yy.index(0);
    int indexZ = zz.index(0);

    int cols = columns;
    for (int row = rows; --row >= 0;) {
      double sum = 0;
      // loop unrolled
      int i = indexA - As;
      int j = indexY - ys;
      for (int k = cols % 4; --k >= 0;) {
        sum += AElems[i += As] * yElems[j += ys];
      }
      for (int k = cols / 4; --k >= 0;) {
        sum += AElems[i += As] * yElems[j += ys]
            + AElems[i += As] * yElems[j += ys] 
            + AElems[i += As] * yElems[j += ys]
            + AElems[i += As] * yElems[j += ys];
      }

      zElems[indexZ] = alpha * sum + beta * zElems[indexZ];
      indexA += this.rowStride;
      indexZ += zs;
    }

    return z;
  }

  @Override
  public DoubleMatrix2D zMult(DoubleMatrix2D B, DoubleMatrix2D C, double alpha, double beta, boolean transposeA,
                              boolean transposeB) {
    // overriden for performance only
    if (transposeA) {
      return viewDice().zMult(B, C, alpha, beta, false, transposeB);
    }
    if (B instanceof SparseDoubleMatrix2D) {
      // exploit quick sparse mult
      // A*B = (B' * A')'
      if (C == null) {
        return B.zMult(this, null, alpha, beta, !transposeB, true).viewDice();
      } else {
        B.zMult(this, C.viewDice(), alpha, beta, !transposeB, true);
        return C;
      }
    }
    if (transposeB) {
      return this.zMult(B.viewDice(), C, alpha, beta, transposeA, false);
    }

    int m = rows;
    int n = columns;
    int p = B.columns;
    if (C == null) {
      C = new DenseDoubleMatrix2D(m, p);
    }
    if (!(C instanceof DenseDoubleMatrix2D)) {
      return super.zMult(B, C, alpha, beta, transposeA, transposeB);
    }
    if (B.rows != n) {
      throw new IllegalArgumentException(
          "Matrix2D inner dimensions must agree");
    }
    if (C.rows != m || C.columns != p) {
      throw new IllegalArgumentException(
          "Incompatible result matrix");
    }
    if (this == C || B == C) {
      throw new IllegalArgumentException("Matrices must not be identical");
    }

    DenseDoubleMatrix2D BB = (DenseDoubleMatrix2D) B;
    DenseDoubleMatrix2D CC = (DenseDoubleMatrix2D) C;
    double[] AElems = this.elements;
    double[] BElems = BB.elements;
    double[] CElems = CC.elements;
    if (AElems == null || BElems == null || CElems == null) {
      throw new IllegalStateException();
    }

    int cA = this.columnStride;
    int cB = BB.columnStride;
    int cC = CC.columnStride;

    int rA = this.rowStride;
    int rB = BB.rowStride;
    int rC = CC.rowStride;

    /*
    A is blocked to hide memory latency
        xxxxxxx B
        xxxxxxx
        xxxxxxx
    A
    xxx     xxxxxxx C
    xxx     xxxxxxx
    ---     -------
    xxx     xxxxxxx
    xxx     xxxxxxx
    ---     -------
    xxx     xxxxxxx
    */
    int blockSize = 30000; // * 8 == Level 2 cache in bytes
    //if (n+p == 0) return C;
    //int m_optimal = (BLOCK_SIZE - n*p) / (n+p);
    int mOptimal = (blockSize - n) / (n + 1);
    if (mOptimal <= 0) {
      mOptimal = 1;
    }
    int blocks = m / mOptimal;
    if (m % mOptimal != 0) {
      blocks++;
    }
    int rr = 0;
    while (--blocks >= 0) {
      int jB = BB.index(0, 0);
      int indexA = index(rr, 0);
      int jC = CC.index(rr, 0);
      rr += mOptimal;
      if (blocks == 0) {
        mOptimal += m - rr;
      }

      for (int j = p; --j >= 0;) {
        int iA = indexA;
        int iC = jC;
        for (int i = mOptimal; --i >= 0;) {
          int kA = iA;
          int kB = jB;

          /*
          // not unrolled:
          for (int k = n; --k >= 0; ) {
            //s += getQuick(i,k) * B.getQuick(k,j);
            s += AElems[kA] * BElems[kB];
            kB += rB;
            kA += cA;
          }
          */

          // loop unrolled
          kA -= cA;
          kB -= rB;

          double s = 0;
          for (int k = n % 4; --k >= 0;) {
            s += AElems[kA += cA] * BElems[kB += rB];
          }
          for (int k = n / 4; --k >= 0;) {
            s += AElems[kA += cA] * BElems[kB += rB]
                + AElems[kA += cA] * BElems[kB += rB] 
                + AElems[kA += cA] * BElems[kB += rB]
                + AElems[kA += cA] * BElems[kB += rB];
          }

          CElems[iC] = alpha * s + beta * CElems[iC];
          iA += rA;
          iC += rC;
        }
        jB += cB;
        jC += cC;
      }
    }
    return C;
  }

  
  @Override
  public double zSum() {
    double[] elems = this.elements;
    if (elems == null) {
      throw new IllegalStateException();
    }
    int index = index(0, 0);
    int cs = this.columnStride;
    int rs = this.rowStride;
    double sum = 0;
    for (int row = rows; --row >= 0;) {
      for (int i = index, column = columns; --column >= 0;) {
        sum += elems[i];
        i += cs;
      }
      index += rs;
    }
    return sum;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy