All downloads are free. The search and download functionality uses the official Maven repository.

com.expleague.ml.models.gpf.Tensor3 Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.models.gpf;

import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;

/**
 * Dense three-dimensional tensor of doubles stored in a single flat {@link ArrayVec}.
 *
 * <p>Element {@code (i, j, k)} lives at flat offset {@code dim3 * (dim2 * i + j) + k},
 * i.e. the last index varies fastest, so each {@code (i, j)} pair addresses a contiguous
 * "row" of {@code dim3} values.
 *
 * <p>Index bounds are checked with {@code assert} only, so they are enforced solely when
 * assertions are enabled (e.g. {@code java -ea}).
 *
 * <p>Originally created by irlab on 16.05.14.
 */
public class Tensor3 {
  /** Flat backing storage of length {@code dim1 * dim2 * dim3}. */
  public final ArrayVec vec;
  /** Extent of the first (slowest varying) index. */
  public final int dim1;
  /** Extent of the second index. */
  public final int dim2;
  /** Extent of the third (fastest varying) index; also the length of a row. */
  public final int dim3;

  /**
   * Creates a tensor with the given extents, backed by a freshly allocated vector.
   *
   * @param dim1 extent of the first index, must be non-negative
   * @param dim2 extent of the second index, must be non-negative
   * @param dim3 extent of the third index, must be non-negative
   * @throws IllegalArgumentException if any extent is negative or the total number of
   *         elements does not fit in an {@code int}
   */
  public Tensor3(final int dim1, final int dim2, final int dim3) {
    final int size = checkedSize(dim1, dim2, dim3);
    this.dim1 = dim1;
    this.dim2 = dim2;
    this.dim3 = dim3;
    vec = new ArrayVec(size);
  }

  /** Validates the extents and returns their product, guarding against int overflow. */
  private static int checkedSize(final int dim1, final int dim2, final int dim3) {
    if (dim1 < 0 || dim2 < 0 || dim3 < 0)
      throw new IllegalArgumentException(
          "Tensor dimensions must be non-negative: " + dim1 + " x " + dim2 + " x " + dim3);
    if (dim1 == 0 || dim2 == 0 || dim3 == 0)
      return 0;
    // dim1 * dim2 always fits in a long; only multiply by dim3 once that partial
    // product is known to be int-sized, so the long arithmetic itself cannot overflow.
    final long plane = (long) dim1 * dim2;
    final long total = plane <= Integer.MAX_VALUE ? plane * dim3 : Long.MAX_VALUE;
    if (total > Integer.MAX_VALUE)
      throw new IllegalArgumentException(
          "Tensor is too large: " + dim1 + " x " + dim2 + " x " + dim3);
    return (int) total;
  }

  /**
   * Maps a three-dimensional index to its offset in {@link #vec}.
   * Bounds are verified only when assertions are enabled.
   */
  private int index(final int i, final int j, final int k) {
    assert(0 <= i && i < dim1);
    assert(0 <= j && j < dim2);
    assert(0 <= k && k < dim3);
    return dim3 * (dim2 * i + j) + k;
  }

  /**
   * Reads a single element.
   *
   * @return the value stored at {@code (i, j, k)}
   */
  public double get(final int i, final int j, final int k) {
    return vec.get(index(i, j, k));
  }

  /**
   * Writes a single element.
   *
   * @param val value to store at {@code (i, j, k)}
   * @return this tensor, to allow call chaining
   */
  public Tensor3 set(final int i, final int j, final int k, final double val) {
    vec.set(index(i, j, k), val);
    return this;
  }

  /**
   * Returns the {@code dim3} consecutive elements addressed by {@code (i, j)}, obtained
   * through {@code vec.sub(...)}.
   * NOTE(review): whether the result is a view sharing storage with this tensor or a copy
   * depends on {@code ArrayVec.sub} - confirm before mutating the result.
   */
  public ArrayVec getRow(final int i, final int j) {
    return vec.sub(index(i, j, 0), dim3);
  }

  /**
   * Copies {@code val} element by element into the row addressed by {@code (i, j)}.
   *
   * @param val source vector; its dimension must equal {@link #dim3}
   * @return this tensor, to allow call chaining
   * @throws IllegalArgumentException if {@code val.dim() != dim3}
   */
  public Tensor3 setRow(final int i, final int j, final Vec val) {
    if (val.dim() != dim3)
      throw new IllegalArgumentException("val.dim() != dim3, val.dim() = " + val.dim() + ", dim3 = " + dim3);
    final int rowStart = index(i, j, 0);
    for (int l = 0; l < dim3; l++)
      vec.set(rowStart + l, val.get(l));
    return this;
  }

  /**
   * Applies {@code vec.adjust} to the element at {@code (i, j, k)} with the given increment.
   *
   * @return this tensor, to allow call chaining
   */
  public Tensor3 adjust(final int i, final int j, final int k, final double increment) {
    vec.adjust(index(i, j, k), increment);
    return this;
  }

  /**
   * Renders the tensor as text: one line per {@code (i, j)} row with tab-separated values,
   * and an empty line after each block of rows sharing the same {@code i}.
   */
  @Override
  public String toString() {
    final StringBuilder builder = new StringBuilder();
    for (int i = 0; i < dim1; i++) {
      for (int j = 0; j < dim2; j++) {
        for (int k = 0; k < dim3; k++) {
          if (k > 0)
            builder.append("\t");
          builder.append(get(i, j, k));
        }
        builder.append('\n');
      }
      builder.append('\n');
    }
    return builder.toString();
  }

  /**
   * Returns the result of {@code vec.toArray()}.
   * NOTE(review): whether this is the live backing array or a copy depends on
   * {@code ArrayVec.toArray} - confirm before modifying it.
   */
  public double[] toArray() {
    return vec.toArray();
  }

  /**
   * Two tensors are equal when all three extents match and their backing vectors are equal.
   * {@code dim3} is compared explicitly: for tensors with a zero extent the backing vectors
   * have the same (zero) length regardless of {@code dim3}, and {@link #hashCode()} depends
   * on {@code dim3}, so omitting it could make equal objects hash differently.
   */
  @Override
  public boolean equals(final Object o) {
    if (!(o instanceof Tensor3))
      return false;
    final Tensor3 other = (Tensor3) o;
    return other.dim1 == dim1
        && other.dim2 == dim2
        && other.dim3 == dim3
        && other.vec.equals(vec);
  }

  /** Hash derived from the backing vector and the row/column extents. */
  @Override
  public int hashCode() {
    return (vec.hashCode() << 1) + dim2 * dim3;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy