All downloads are free. Search and download functionality uses the official Maven repository.

swim.math.Tensor Maven / Gradle / Ivy

There is a newer version: 3.10.0
Show newest version
// Copyright 2015-2019 SWIM.AI inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package swim.math;

import java.util.Arrays;
import swim.codec.Debug;
import swim.codec.Format;
import swim.codec.Output;
import swim.structure.Item;
import swim.structure.Record;
import swim.structure.Value;
import swim.util.Murmur3;

public class Tensor implements Debug {
  public final TensorDims dims;
  public final Object array;
  public final int offset;

  /**
   * Primary constructor: wraps an already-allocated backing array without copying it.
   *
   * @param dims   the axis chain describing this tensor's shape and strides
   * @param array  the backing storage; the accessors in this class expect a
   *               {@code double[]} or {@code float[]}
   * @param offset index in {@code array} at which this tensor's data starts
   */
  protected Tensor(TensorDims dims, Object array, int offset) {
    this.offset = offset;
    this.array = array;
    this.dims = dims;
  }

  /**
   * Creates a double-precision tensor view over {@code array}, starting at {@code offset}.
   * The array is shared, not copied.
   */
  public Tensor(TensorDims dims, double[] array, int offset) {
    this(dims, (Object) array, offset);
  }

  /**
   * Creates a single-precision tensor view over {@code array}, starting at {@code offset}.
   * The array is shared, not copied.
   */
  public Tensor(TensorDims dims, float[] array, int offset) {
    this(dims, (Object) array, offset);
  }

  /**
   * Creates a double-precision tensor whose data starts at index 0 of {@code array}.
   * The array is shared, not copied.
   */
  public Tensor(TensorDims dims, double... array) {
    this(dims, (Object) array, 0);
  }

  /**
   * Creates a single-precision tensor whose data starts at index 0 of {@code array}.
   * The array is shared, not copied.
   */
  public Tensor(TensorDims dims, float... array) {
    this(dims, (Object) array, 0);
  }

  /** Returns the axis chain (sizes and strides) of this tensor; same as the {@code dims} field. */
  public final TensorDims dimensions() {
    final TensorDims result = dims;
    return result;
  }

  /**
   * Reports the floating-point precision of the backing array.
   *
   * @return {@code Precision.f64()} for a {@code double[]}, {@code Precision.f32()} for a
   *         {@code float[]}
   * @throws AssertionError if the backing array is of any other type
   */
  public final Precision precision() {
    final Object data = this.array;
    if (data instanceof float[]) {
      return Precision.f32();
    }
    if (data instanceof double[]) {
      return Precision.f64();
    }
    throw new AssertionError();
  }

  /**
   * Translates multi-dimensional coordinates into a flat index in the backing array.
   *
   * <p>Walks the axis chain starting at {@code dim}. For each axis it adds
   * {@code coords[i] * stride} to {@code offset}.
   *
   * @param dim    the outermost axis of the chain
   * @param coords one coordinate per axis, outermost first
   * @param offset the base index to which the per-axis contributions are added
   * @return the flat array index of the addressed element
   * @throws IndexOutOfBoundsException if any coordinate is outside {@code [0, size)} for its
   *         axis, or if the number of coordinates differs from the number of axes
   */
  protected static int getOffset(TensorDims dim, int[] coords, int offset) {
    int i = 0;
    while (dim != null) {
      if (i >= coords.length) {
        // Too few coordinates: the element address would be incomplete.
        throw new IndexOutOfBoundsException(Arrays.toString(coords));
      }
      final int k = coords[i];
      if (k < 0 || k >= dim.size) {
        throw new IndexOutOfBoundsException(Arrays.toString(coords));
      }
      offset += k * dim.stride;
      dim = dim.next;
      i += 1;
    }
    if (i != coords.length) {
      // Too many coordinates: silently ignoring the extras would hide caller mistakes.
      throw new IndexOutOfBoundsException(Arrays.toString(coords));
    }
    return offset;
  }

  /**
   * Returns the element at {@code coords} as a {@code double}, widening single-precision data.
   *
   * <p>Coordinates are resolved relative to this tensor's {@code offset}. This matches how the
   * arithmetic methods of this class start reading at {@code u.offset}.
   *
   * @param coords one coordinate per axis, outermost first
   * @throws IndexOutOfBoundsException if the coordinates do not address an element
   * @throws AssertionError if the backing array is neither {@code double[]} nor {@code float[]}
   */
  public final double getDouble(int... coords) {
    final Object us = this.array;
    if (us instanceof double[]) {
      return ((double[]) us)[getOffset(this.dims, coords, this.offset)];
    } else if (us instanceof float[]) {
      return (double) ((float[]) us)[getOffset(this.dims, coords, this.offset)];
    } else {
      throw new AssertionError();
    }
  }

  /**
   * Returns the element at {@code coords} as a {@code float}, narrowing double-precision data.
   *
   * <p>Coordinates are resolved relative to this tensor's {@code offset}. This matches how the
   * arithmetic methods of this class start reading at {@code u.offset}.
   *
   * @param coords one coordinate per axis, outermost first
   * @throws IndexOutOfBoundsException if the coordinates do not address an element
   * @throws AssertionError if the backing array is neither {@code float[]} nor {@code double[]}
   */
  public final float getFloat(int... coords) {
    final Object us = this.array;
    if (us instanceof float[]) {
      return ((float[]) us)[getOffset(this.dims, coords, this.offset)];
    } else if (us instanceof double[]) {
      return (float) ((double[]) us)[getOffset(this.dims, coords, this.offset)];
    } else {
      throw new AssertionError();
    }
  }

  /** Returns a new tensor holding the element-wise sum {@code this + that}. */
  public final Tensor plus(Tensor that) {
    final Tensor sum = Tensor.add(this, that);
    return sum;
  }

  /**
   * Returns a new tensor holding the element-wise sum {@code u + v}.
   *
   * <p>The result takes its dimensions from {@code u}. Its precision is the {@code max} of the
   * two operands' precisions.
   */
  public static Tensor add(Tensor u, Tensor v) {
    final Precision wp = u.precision().max(v.precision());
    return add(u, v, u.dims, wp);
  }

  /** Writes the element-wise sum {@code u + v} into the existing tensor {@code w}. */
  public static void add(Tensor u, Tensor v, MutableTensor w) {
    add(u.dims, u.array, u.offset,
        v.dims, v.array, v.offset,
        w.dims, w.array, w.offset);
  }

  /**
   * Returns a new tensor of shape {@code wd} and precision {@code wp} holding {@code u + v}.
   *
   * <p>A fresh backing array of {@code wd.size * wd.stride} elements is allocated. The result
   * starts at offset 0.
   *
   * @throws AssertionError if {@code wp} is neither double nor single precision
   */
  public static Tensor add(Tensor u, Tensor v, TensorDims wd, Precision wp) {
    final boolean f64 = wp.isDouble();
    if (!f64 && !wp.isSingle()) {
      throw new AssertionError();
    }
    final int length = wd.size * wd.stride;
    final Object ws = f64 ? new double[length] : new float[length];
    add(u.dims, u.array, u.offset, v.dims, v.array, v.offset, wd, ws, 0);
    return new Tensor(wd, ws, 0);
  }

  /**
   * Dispatches an element-wise sum to the typed kernel that matches the runtime types of
   * {@code us}, {@code vs} and {@code ws}.
   *
   * @throws AssertionError if any of the arrays is neither a {@code double[]} nor a
   *         {@code float[]}
   */
  public static void add(TensorDims ud, Object us, int ui,
                         TensorDims vd, Object vs, int vi,
                         TensorDims wd, Object ws, int wi) {
    if (us instanceof double[] && vs instanceof double[] && ws instanceof double[]) {
      add(ud, (double[]) us, ui, vd, (double[]) vs, vi, wd, (double[]) ws, wi);
    } else if (us instanceof double[] && vs instanceof double[] && ws instanceof float[]) {
      add(ud, (double[]) us, ui, vd, (double[]) vs, vi, wd, (float[]) ws, wi);
    } else if (us instanceof double[] && vs instanceof float[] && ws instanceof double[]) {
      add(ud, (double[]) us, ui, vd, (float[]) vs, vi, wd, (double[]) ws, wi);
    } else if (us instanceof double[] && vs instanceof float[] && ws instanceof float[]) {
      add(ud, (double[]) us, ui, vd, (float[]) vs, vi, wd, (float[]) ws, wi);
    } else if (us instanceof float[] && vs instanceof double[] && ws instanceof double[]) {
      add(ud, (float[]) us, ui, vd, (double[]) vs, vi, wd, (double[]) ws, wi);
    } else if (us instanceof float[] && vs instanceof double[] && ws instanceof float[]) {
      add(ud, (float[]) us, ui, vd, (double[]) vs, vi, wd, (float[]) ws, wi);
    } else if (us instanceof float[] && vs instanceof float[] && ws instanceof double[]) {
      add(ud, (float[]) us, ui, vd, (float[]) vs, vi, wd, (double[]) ws, wi);
    } else if (us instanceof float[] && vs instanceof float[] && ws instanceof float[]) {
      add(ud, (float[]) us, ui, vd, (float[]) vs, vi, wd, (float[]) ws, wi);
    } else {
      throw new AssertionError();
    }
  }

  /**
   * Sum kernel: double {@code u} + double {@code v} into double {@code w}.
   *
   * <p>Each operand is an axis chain, a backing array and a starting index. Outer axes recurse
   * on {@code next}, and every index advances by its own operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void add(TensorDims ud, double[] us, int ui,
                         TensorDims vd, double[] vs, int vi,
                         TensorDims wd, double[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = us[u] + vs[v];
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        add(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /**
   * Sum kernel: double {@code u} + double {@code v} into float {@code w}.
   *
   * <p>The sum is computed in double precision, then narrowed to {@code float}. Outer axes
   * recurse on {@code next}, and every index advances by its own operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void add(TensorDims ud, double[] us, int ui,
                         TensorDims vd, double[] vs, int vi,
                         TensorDims wd, float[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = (float) (us[u] + vs[v]);
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        add(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /**
   * Sum kernel: double {@code u} + float {@code v} into double {@code w}.
   *
   * <p>{@code v} is widened to double before adding. Outer axes recurse on {@code next}, and
   * every index advances by its own operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void add(TensorDims ud, double[] us, int ui,
                         TensorDims vd, float[] vs, int vi,
                         TensorDims wd, double[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = us[u] + (double) vs[v];
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        add(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /**
   * Sum kernel: double {@code u} + float {@code v} into float {@code w}.
   *
   * <p>{@code v} is widened, the sum is computed in double, then narrowed to {@code float}.
   * Outer axes recurse on {@code next}, and every index advances by its own operand's
   * {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void add(TensorDims ud, double[] us, int ui,
                         TensorDims vd, float[] vs, int vi,
                         TensorDims wd, float[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = (float) (us[u] + (double) vs[v]);
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        add(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /**
   * Sum kernel: float {@code u} + double {@code v} into double {@code w}.
   *
   * <p>{@code u} is widened to double before adding. Outer axes recurse on {@code next}, and
   * every index advances by its own operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void add(TensorDims ud, float[] us, int ui,
                         TensorDims vd, double[] vs, int vi,
                         TensorDims wd, double[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = (double) us[u] + vs[v];
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        add(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /**
   * Sum kernel: float {@code u} + double {@code v} into float {@code w}.
   *
   * <p>{@code u} is widened, the sum is computed in double, then narrowed to {@code float}.
   * Outer axes recurse on {@code next}, and every index advances by its own operand's
   * {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void add(TensorDims ud, float[] us, int ui,
                         TensorDims vd, double[] vs, int vi,
                         TensorDims wd, float[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = (float) ((double) us[u] + vs[v]);
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        add(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /**
   * Sum kernel: float {@code u} + float {@code v} into double {@code w}.
   *
   * <p>Both inputs are widened to double before adding. Outer axes recurse on {@code next},
   * and every index advances by its own operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void add(TensorDims ud, float[] us, int ui,
                         TensorDims vd, float[] vs, int vi,
                         TensorDims wd, double[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = (double) us[u] + (double) vs[v];
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        add(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /**
   * Sum kernel: float {@code u} + float {@code v} into float {@code w}.
   *
   * <p>Both inputs are widened, the sum is computed in double, then narrowed to {@code float}.
   * Outer axes recurse on {@code next}, and every index advances by its own operand's
   * {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void add(TensorDims ud, float[] us, int ui,
                         TensorDims vd, float[] vs, int vi,
                         TensorDims wd, float[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = (float) ((double) us[u] + (double) vs[v]);
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        add(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /** Returns a new tensor holding the element-wise negation of this tensor. */
  public final Tensor opposite() {
    final Tensor negated = Tensor.opposite(this);
    return negated;
  }

  /** Returns a new tensor with {@code u}'s dimensions and precision holding {@code -u}. */
  public static Tensor opposite(Tensor u) {
    final Precision wp = u.precision();
    return opposite(u, u.dims, wp);
  }

  /** Writes the element-wise negation of {@code u} into the existing tensor {@code w}. */
  public static void opposite(Tensor u, MutableTensor w) {
    opposite(u.dims, u.array, u.offset,
             w.dims, w.array, w.offset);
  }

  /**
   * Returns a new tensor of shape {@code wd} and precision {@code wp} holding {@code -u}.
   *
   * <p>A fresh backing array of {@code wd.size * wd.stride} elements is allocated. The result
   * starts at offset 0.
   *
   * @throws AssertionError if {@code wp} is neither double nor single precision
   */
  public static Tensor opposite(Tensor u, TensorDims wd, Precision wp) {
    final boolean f64 = wp.isDouble();
    if (!f64 && !wp.isSingle()) {
      throw new AssertionError();
    }
    final int length = wd.size * wd.stride;
    final Object ws = f64 ? new double[length] : new float[length];
    opposite(u.dims, u.array, u.offset, wd, ws, 0);
    return new Tensor(wd, ws, 0);
  }

  /**
   * Dispatches an element-wise negation to the typed kernel that matches the runtime types of
   * {@code us} and {@code ws}.
   *
   * @throws AssertionError if either array is neither a {@code double[]} nor a {@code float[]}
   */
  public static void opposite(TensorDims ud, Object us, int ui,
                              TensorDims wd, Object ws, int wi) {
    if (us instanceof double[] && ws instanceof double[]) {
      opposite(ud, (double[]) us, ui, wd, (double[]) ws, wi);
    } else if (us instanceof double[] && ws instanceof float[]) {
      opposite(ud, (double[]) us, ui, wd, (float[]) ws, wi);
    } else if (us instanceof float[] && ws instanceof double[]) {
      opposite(ud, (float[]) us, ui, wd, (double[]) ws, wi);
    } else if (us instanceof float[] && ws instanceof float[]) {
      opposite(ud, (float[]) us, ui, wd, (float[]) ws, wi);
    } else {
      throw new AssertionError();
    }
  }

  /**
   * Negation kernel: double {@code u} into double {@code w}.
   *
   * <p>Outer axes recurse on {@code next}, and each index advances by its own operand's
   * {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void opposite(TensorDims ud, double[] us, int ui,
                              TensorDims wd, double[] ws, int wi) {
    if (wd.size != ud.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the input must not have a deeper axis either.
      if (ud.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        ws[w] = -us[u];
      }
    } else {
      // Outer axis: the input needs a matching inner axis to recurse into.
      if (ud.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        opposite(ud.next, us, u, wd.next, ws, w);
      }
    }
  }

  /**
   * Negation kernel: double {@code u} into float {@code w}.
   *
   * <p>Each value is negated in double, then narrowed to {@code float}. Outer axes recurse on
   * {@code next}, and each index advances by its own operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void opposite(TensorDims ud, double[] us, int ui,
                              TensorDims wd, float[] ws, int wi) {
    if (wd.size != ud.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the input must not have a deeper axis either.
      if (ud.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        ws[w] = (float) -us[u];
      }
    } else {
      // Outer axis: the input needs a matching inner axis to recurse into.
      if (ud.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        opposite(ud.next, us, u, wd.next, ws, w);
      }
    }
  }

  /**
   * Negation kernel: float {@code u} into double {@code w}.
   *
   * <p>Each value is widened to double, then negated. Outer axes recurse on {@code next}, and
   * each index advances by its own operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void opposite(TensorDims ud, float[] us, int ui,
                              TensorDims wd, double[] ws, int wi) {
    if (wd.size != ud.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the input must not have a deeper axis either.
      if (ud.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        ws[w] = -((double) us[u]);
      }
    } else {
      // Outer axis: the input needs a matching inner axis to recurse into.
      if (ud.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        opposite(ud.next, us, u, wd.next, ws, w);
      }
    }
  }

  /**
   * Negation kernel: float {@code u} into float {@code w}.
   *
   * <p>Outer axes recurse on {@code next}, and each index advances by its own operand's
   * {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void opposite(TensorDims ud, float[] us, int ui,
                              TensorDims wd, float[] ws, int wi) {
    if (wd.size != ud.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the input must not have a deeper axis either.
      if (ud.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        ws[w] = -us[u];
      }
    } else {
      // Outer axis: the input needs a matching inner axis to recurse into.
      if (ud.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        opposite(ud.next, us, u, wd.next, ws, w);
      }
    }
  }

  /** Returns a new tensor holding the element-wise difference {@code this - that}. */
  public final Tensor minus(Tensor that) {
    final Tensor difference = Tensor.subtract(this, that);
    return difference;
  }

  /**
   * Returns a new tensor holding the element-wise difference {@code u - v}.
   *
   * <p>The result takes its dimensions from {@code u}. Its precision is the {@code max} of the
   * two operands' precisions.
   */
  public static Tensor subtract(Tensor u, Tensor v) {
    final Precision wp = u.precision().max(v.precision());
    return subtract(u, v, u.dims, wp);
  }

  /** Writes the element-wise difference {@code u - v} into the existing tensor {@code w}. */
  public static void subtract(Tensor u, Tensor v, MutableTensor w) {
    subtract(u.dims, u.array, u.offset,
             v.dims, v.array, v.offset,
             w.dims, w.array, w.offset);
  }

  /**
   * Returns a new tensor of shape {@code wd} and precision {@code wp} holding {@code u - v}.
   *
   * <p>A fresh backing array of {@code wd.size * wd.stride} elements is allocated. The result
   * starts at offset 0.
   *
   * @throws AssertionError if {@code wp} is neither double nor single precision
   */
  public static Tensor subtract(Tensor u, Tensor v, TensorDims wd, Precision wp) {
    final boolean f64 = wp.isDouble();
    if (!f64 && !wp.isSingle()) {
      throw new AssertionError();
    }
    final int length = wd.size * wd.stride;
    final Object ws = f64 ? new double[length] : new float[length];
    subtract(u.dims, u.array, u.offset, v.dims, v.array, v.offset, wd, ws, 0);
    return new Tensor(wd, ws, 0);
  }

  /**
   * Dispatches an element-wise difference to the typed kernel that matches the runtime types
   * of {@code us}, {@code vs} and {@code ws}.
   *
   * @throws AssertionError if any of the arrays is neither a {@code double[]} nor a
   *         {@code float[]}
   */
  public static void subtract(TensorDims ud, Object us, int ui,
                              TensorDims vd, Object vs, int vi,
                              TensorDims wd, Object ws, int wi) {
    if (us instanceof double[] && vs instanceof double[] && ws instanceof double[]) {
      subtract(ud, (double[]) us, ui, vd, (double[]) vs, vi, wd, (double[]) ws, wi);
    } else if (us instanceof double[] && vs instanceof double[] && ws instanceof float[]) {
      subtract(ud, (double[]) us, ui, vd, (double[]) vs, vi, wd, (float[]) ws, wi);
    } else if (us instanceof double[] && vs instanceof float[] && ws instanceof double[]) {
      subtract(ud, (double[]) us, ui, vd, (float[]) vs, vi, wd, (double[]) ws, wi);
    } else if (us instanceof double[] && vs instanceof float[] && ws instanceof float[]) {
      subtract(ud, (double[]) us, ui, vd, (float[]) vs, vi, wd, (float[]) ws, wi);
    } else if (us instanceof float[] && vs instanceof double[] && ws instanceof double[]) {
      subtract(ud, (float[]) us, ui, vd, (double[]) vs, vi, wd, (double[]) ws, wi);
    } else if (us instanceof float[] && vs instanceof double[] && ws instanceof float[]) {
      subtract(ud, (float[]) us, ui, vd, (double[]) vs, vi, wd, (float[]) ws, wi);
    } else if (us instanceof float[] && vs instanceof float[] && ws instanceof double[]) {
      subtract(ud, (float[]) us, ui, vd, (float[]) vs, vi, wd, (double[]) ws, wi);
    } else if (us instanceof float[] && vs instanceof float[] && ws instanceof float[]) {
      subtract(ud, (float[]) us, ui, vd, (float[]) vs, vi, wd, (float[]) ws, wi);
    } else {
      throw new AssertionError();
    }
  }

  /**
   * Difference kernel: double {@code u} - double {@code v} into double {@code w}.
   *
   * <p>Each operand is an axis chain, a backing array and a starting index. Outer axes recurse
   * on {@code next}, and every index advances by its own operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void subtract(TensorDims ud, double[] us, int ui,
                              TensorDims vd, double[] vs, int vi,
                              TensorDims wd, double[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = us[u] - vs[v];
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        subtract(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /**
   * Difference kernel: double {@code u} - double {@code v} into float {@code w}.
   *
   * <p>The difference is computed in double, then narrowed to {@code float}. Outer axes recurse
   * on {@code next}, and every index advances by its own operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void subtract(TensorDims ud, double[] us, int ui,
                              TensorDims vd, double[] vs, int vi,
                              TensorDims wd, float[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = (float) (us[u] - vs[v]);
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        subtract(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /**
   * Difference kernel: double {@code u} - float {@code v} into double {@code w}.
   *
   * <p>{@code v} is widened to double before subtracting. Outer axes recurse on {@code next},
   * and every index advances by its own operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void subtract(TensorDims ud, double[] us, int ui,
                              TensorDims vd, float[] vs, int vi,
                              TensorDims wd, double[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = us[u] - (double) vs[v];
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        subtract(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /**
   * Difference kernel: double {@code u} - float {@code v} into float {@code w}.
   *
   * <p>{@code v} is widened, the difference is computed in double, then narrowed to
   * {@code float}. Outer axes recurse on {@code next}, and every index advances by its own
   * operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void subtract(TensorDims ud, double[] us, int ui,
                              TensorDims vd, float[] vs, int vi,
                              TensorDims wd, float[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = (float) (us[u] - (double) vs[v]);
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        subtract(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /**
   * Difference kernel: float {@code u} - double {@code v} into double {@code w}.
   *
   * <p>{@code u} is widened to double before subtracting. Outer axes recurse on {@code next},
   * and every index advances by its own operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void subtract(TensorDims ud, float[] us, int ui,
                              TensorDims vd, double[] vs, int vi,
                              TensorDims wd, double[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = (double) us[u] - vs[v];
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        subtract(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /**
   * Difference kernel: float {@code u} - double {@code v} into float {@code w}.
   *
   * <p>{@code u} is widened, the difference is computed in double, then narrowed to
   * {@code float}. Outer axes recurse on {@code next}, and every index advances by its own
   * operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void subtract(TensorDims ud, float[] us, int ui,
                              TensorDims vd, double[] vs, int vi,
                              TensorDims wd, float[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = (float) ((double) us[u] - vs[v]);
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        subtract(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /**
   * Difference kernel: float {@code u} - float {@code v} into double {@code w}.
   *
   * <p>Both inputs are widened to double before subtracting. Outer axes recurse on
   * {@code next}, and every index advances by its own operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void subtract(TensorDims ud, float[] us, int ui,
                              TensorDims vd, float[] vs, int vi,
                              TensorDims wd, double[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = (double) us[u] - (double) vs[v];
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        subtract(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /**
   * Difference kernel: float {@code u} - float {@code v} into float {@code w}.
   *
   * <p>Both inputs are widened, the difference is computed in double, then narrowed to
   * {@code float}. Outer axes recurse on {@code next}, and every index advances by its own
   * operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void subtract(TensorDims ud, float[] us, int ui,
                              TensorDims vd, float[] vs, int vi,
                              TensorDims wd, float[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size || ud.size != vd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the inputs must not have deeper axes either.
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        ws[w] = (float) ((double) us[u] - (double) vs[v]);
      }
    } else {
      // Outer axis: the inputs need a matching inner axis to recurse into.
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, v = vi, w = wi; w < end;
           u += ud.stride, v += vd.stride, w += wd.stride) {
        subtract(ud.next, us, u, vd.next, vs, v, wd.next, ws, w);
      }
    }
  }

  /** Returns a new tensor with every element of this tensor multiplied by {@code scalar}. */
  public final Tensor times(double scalar) {
    final Tensor scaled = Tensor.multiply(scalar, this);
    return scaled;
  }

  /** Returns a new tensor with {@code u}'s dimensions and precision holding {@code a * u}. */
  public static Tensor multiply(double a, Tensor u) {
    final Precision wp = u.precision();
    return multiply(a, u, u.dims, wp);
  }

  /** Writes {@code a * u} (every element scaled by {@code a}) into the existing tensor {@code w}. */
  public static void multiply(double a, Tensor u, MutableTensor w) {
    multiply(a,
             u.dims, u.array, u.offset,
             w.dims, w.array, w.offset);
  }

  /**
   * Returns a new tensor of shape {@code wd} and precision {@code wp} holding {@code a * u}.
   *
   * <p>A fresh backing array of {@code wd.size * wd.stride} elements is allocated. The result
   * starts at offset 0.
   *
   * @throws AssertionError if {@code wp} is neither double nor single precision
   */
  public static Tensor multiply(double a, Tensor u, TensorDims wd, Precision wp) {
    final boolean f64 = wp.isDouble();
    if (!f64 && !wp.isSingle()) {
      throw new AssertionError();
    }
    final int length = wd.size * wd.stride;
    final Object ws = f64 ? new double[length] : new float[length];
    multiply(a, u.dims, u.array, u.offset, wd, ws, 0);
    return new Tensor(wd, ws, 0);
  }

  /**
   * Dispatches a scalar multiplication to the typed kernel that matches the runtime types of
   * {@code us} and {@code ws}.
   *
   * @throws AssertionError if either array is neither a {@code double[]} nor a {@code float[]}
   */
  public static void multiply(double a, TensorDims ud, Object us, int ui,
                              TensorDims wd, Object ws, int wi) {
    if (us instanceof double[] && ws instanceof double[]) {
      multiply(a, ud, (double[]) us, ui, wd, (double[]) ws, wi);
    } else if (us instanceof double[] && ws instanceof float[]) {
      multiply(a, ud, (double[]) us, ui, wd, (float[]) ws, wi);
    } else if (us instanceof float[] && ws instanceof double[]) {
      multiply(a, ud, (float[]) us, ui, wd, (double[]) ws, wi);
    } else if (us instanceof float[] && ws instanceof float[]) {
      multiply(a, ud, (float[]) us, ui, wd, (float[]) ws, wi);
    } else {
      throw new AssertionError();
    }
  }

  /**
   * Scalar-multiply kernel: {@code a} times double {@code u} into double {@code w}.
   *
   * <p>Outer axes recurse on {@code next}, and each index advances by its own operand's
   * {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void multiply(double a, TensorDims ud, double[] us, int ui,
                              TensorDims wd, double[] ws, int wi) {
    if (wd.size != ud.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the input must not have a deeper axis either.
      if (ud.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        ws[w] = a * us[u];
      }
    } else {
      // Outer axis: the input needs a matching inner axis to recurse into.
      if (ud.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        multiply(a, ud.next, us, u, wd.next, ws, w);
      }
    }
  }

  /**
   * Scalar-multiply kernel: {@code a} times double {@code u} into float {@code w}.
   *
   * <p>The product is computed in double, then narrowed to {@code float}. Outer axes recurse on
   * {@code next}, and each index advances by its own operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void multiply(double a, TensorDims ud, double[] us, int ui,
                              TensorDims wd, float[] ws, int wi) {
    if (wd.size != ud.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the input must not have a deeper axis either.
      if (ud.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        ws[w] = (float) (a * us[u]);
      }
    } else {
      // Outer axis: the input needs a matching inner axis to recurse into.
      if (ud.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        multiply(a, ud.next, us, u, wd.next, ws, w);
      }
    }
  }

  /**
   * Scalar-multiply kernel: {@code a} times float {@code u} into double {@code w}.
   *
   * <p>Each element of {@code u} is widened to double before multiplying. Outer axes recurse
   * on {@code next}, and each index advances by its own operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void multiply(double a, TensorDims ud, float[] us, int ui,
                              TensorDims wd, double[] ws, int wi) {
    if (wd.size != ud.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the input must not have a deeper axis either.
      if (ud.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        ws[w] = a * (double) us[u];
      }
    } else {
      // Outer axis: the input needs a matching inner axis to recurse into.
      if (ud.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        multiply(a, ud.next, us, u, wd.next, ws, w);
      }
    }
  }

  /**
   * Scalar-multiply kernel: {@code a} times float {@code u} into float {@code w}.
   *
   * <p>Each element is widened, multiplied in double, then narrowed to {@code float}. Outer
   * axes recurse on {@code next}, and each index advances by its own operand's {@code stride}.
   *
   * @throws DimensionException if the operands disagree on an axis size or on rank
   */
  public static void multiply(double a, TensorDims ud, float[] us, int ui,
                              TensorDims wd, float[] ws, int wi) {
    if (wd.size != ud.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (wd.next == null) {
      // Innermost axis: the input must not have a deeper axis either.
      if (ud.next != null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        ws[w] = (float) (a * (double) us[u]);
      }
    } else {
      // Outer axis: the input needs a matching inner axis to recurse into.
      if (ud.next == null) {
        throw new DimensionException();
      }
      for (int u = ui, w = wi; w < end; u += ud.stride, w += wd.stride) {
        multiply(a, ud.next, us, u, wd.next, ws, w);
      }
    }
  }

  /**
   * Multiplies this tensor by {@code that}, delegating to {@link #multiply(Tensor, Tensor)}.
   *
   * <p>NOTE(review): the typed two-tensor kernels are outside this view. This is presumably an
   * element-wise product, mirroring {@code add}; confirm.
   */
  public final Tensor times(Tensor that) {
    final Tensor product = Tensor.multiply(this, that);
    return product;
  }

  /**
   * Returns the element-wise product of {@code u} and {@code v}. The result
   * uses {@code u}'s dimensions and the wider of the two operand precisions.
   */
  public static Tensor multiply(Tensor u, Tensor v) {
    final Precision wp = u.precision().max(v.precision());
    return multiply(u, v, u.dims, wp);
  }

  /**
   * Writes the element-wise product of {@code u} and {@code v} into the
   * existing storage of {@code w}, starting at {@code w}'s offset.
   */
  public static void multiply(Tensor u, Tensor v, MutableTensor w) {
    multiply(u.dims, u.array, u.offset,
             v.dims, v.array, v.offset,
             w.dims, w.array, w.offset);
  }

  /**
   * Returns the element-wise product of {@code u} and {@code v} in a freshly
   * allocated tensor with dimensions {@code wd} and precision {@code wp}.
   *
   * @throws AssertionError if {@code wp} is neither double nor single precision
   */
  public static Tensor multiply(Tensor u, Tensor v, TensorDims wd, Precision wp) {
    if (!wp.isDouble() && !wp.isSingle()) {
      throw new AssertionError();
    }
    // The backing array spans the outermost dimension of wd; the result starts at offset 0.
    final int length = wd.size * wd.stride;
    final Object ws = wp.isDouble() ? new double[length] : new float[length];
    multiply(u.dims, u.array, u.offset, v.dims, v.array, v.offset, wd, ws, 0);
    return new Tensor(wd, ws, 0);
  }

  /**
   * Dispatches the element-wise product to the overload that matches the
   * runtime array types of {@code us}, {@code vs} and {@code ws}.
   *
   * @throws AssertionError if any array is neither a {@code double[]} nor a {@code float[]}
   */
  public static void multiply(TensorDims ud, Object us, int ui,
                              TensorDims vd, Object vs, int vi,
                              TensorDims wd, Object ws, int wi) {
    if (!(us instanceof double[] || us instanceof float[])
        || !(vs instanceof double[] || vs instanceof float[])
        || !(ws instanceof double[] || ws instanceof float[])) {
      throw new AssertionError();
    }
    // Past the guard, "not double[]" means "float[]" for every operand.
    final boolean u64 = us instanceof double[];
    final boolean v64 = vs instanceof double[];
    final boolean w64 = ws instanceof double[];
    if (u64 && v64 && w64) {
      multiply(ud, (double[]) us, ui, vd, (double[]) vs, vi, wd, (double[]) ws, wi);
    } else if (u64 && v64) {
      multiply(ud, (double[]) us, ui, vd, (double[]) vs, vi, wd, (float[]) ws, wi);
    } else if (u64 && w64) {
      multiply(ud, (double[]) us, ui, vd, (float[]) vs, vi, wd, (double[]) ws, wi);
    } else if (u64) {
      multiply(ud, (double[]) us, ui, vd, (float[]) vs, vi, wd, (float[]) ws, wi);
    } else if (v64 && w64) {
      multiply(ud, (float[]) us, ui, vd, (double[]) vs, vi, wd, (double[]) ws, wi);
    } else if (v64) {
      multiply(ud, (float[]) us, ui, vd, (double[]) vs, vi, wd, (float[]) ws, wi);
    } else if (w64) {
      multiply(ud, (float[]) us, ui, vd, (float[]) vs, vi, wd, (double[]) ws, wi);
    } else {
      multiply(ud, (float[]) us, ui, vd, (float[]) vs, vi, wd, (float[]) ws, wi);
    }
  }

  /**
   * Writes the element-wise product of {@code u} and {@code v} into {@code w}.
   * Nested dimensions are handled recursively, and each operand advances by
   * its own stride.
   *
   * @throws DimensionException if the operands' sizes or nesting disagree at
   *         any visited level
   */
  public static void multiply(TensorDims ud, double[] us, int ui,
                              TensorDims vd, double[] vs, int vi,
                              TensorDims wd, double[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size) {
      throw new DimensionException();
    }
    final boolean innermost = wd.next == null;
    if ((ud.next == null) != innermost || (vd.next == null) != innermost) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (innermost) {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        ws[wi] = us[ui] * vs[vi];
      }
    } else {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        multiply(ud.next, us, ui, vd.next, vs, vi, wd.next, ws, wi);
      }
    }
  }

  /**
   * Writes the element-wise product of {@code u} and {@code v} into the
   * single-precision tensor {@code w}, narrowing each product to float.
   * Nested dimensions are handled recursively, and each operand advances by
   * its own stride.
   *
   * @throws DimensionException if the operands' sizes or nesting disagree at
   *         any visited level
   */
  public static void multiply(TensorDims ud, double[] us, int ui,
                              TensorDims vd, double[] vs, int vi,
                              TensorDims wd, float[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size) {
      throw new DimensionException();
    }
    final boolean innermost = wd.next == null;
    if ((ud.next == null) != innermost || (vd.next == null) != innermost) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (innermost) {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        ws[wi] = (float) (us[ui] * vs[vi]);
      }
    } else {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        multiply(ud.next, us, ui, vd.next, vs, vi, wd.next, ws, wi);
      }
    }
  }

  /**
   * Writes the element-wise product of {@code u} (double) and {@code v}
   * (float, widened) into the double-precision tensor {@code w}. Nested
   * dimensions are handled recursively, and each operand advances by its own
   * stride.
   *
   * @throws DimensionException if the operands' sizes or nesting disagree at
   *         any visited level
   */
  public static void multiply(TensorDims ud, double[] us, int ui,
                              TensorDims vd, float[] vs, int vi,
                              TensorDims wd, double[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size) {
      throw new DimensionException();
    }
    final boolean innermost = wd.next == null;
    if ((ud.next == null) != innermost || (vd.next == null) != innermost) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (innermost) {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        ws[wi] = us[ui] * (double) vs[vi];
      }
    } else {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        multiply(ud.next, us, ui, vd.next, vs, vi, wd.next, ws, wi);
      }
    }
  }

  /**
   * Writes the element-wise product of {@code u} (double) and {@code v}
   * (float, widened) into the single-precision tensor {@code w}, narrowing
   * each product to float. Nested dimensions are handled recursively, and
   * each operand advances by its own stride.
   *
   * @throws DimensionException if the operands' sizes or nesting disagree at
   *         any visited level
   */
  public static void multiply(TensorDims ud, double[] us, int ui,
                              TensorDims vd, float[] vs, int vi,
                              TensorDims wd, float[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size) {
      throw new DimensionException();
    }
    final boolean innermost = wd.next == null;
    if ((ud.next == null) != innermost || (vd.next == null) != innermost) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (innermost) {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        ws[wi] = (float) (us[ui] * (double) vs[vi]);
      }
    } else {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        multiply(ud.next, us, ui, vd.next, vs, vi, wd.next, ws, wi);
      }
    }
  }

  /**
   * Writes the element-wise product of {@code u} (float, widened) and
   * {@code v} (double) into the double-precision tensor {@code w}. Nested
   * dimensions are handled recursively, and each operand advances by its own
   * stride.
   *
   * @throws DimensionException if the operands' sizes or nesting disagree at
   *         any visited level
   */
  public static void multiply(TensorDims ud, float[] us, int ui,
                              TensorDims vd, double[] vs, int vi,
                              TensorDims wd, double[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size) {
      throw new DimensionException();
    }
    final boolean innermost = wd.next == null;
    if ((ud.next == null) != innermost || (vd.next == null) != innermost) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (innermost) {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        ws[wi] = (double) us[ui] * vs[vi];
      }
    } else {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        multiply(ud.next, us, ui, vd.next, vs, vi, wd.next, ws, wi);
      }
    }
  }

  /**
   * Writes the element-wise product of {@code u} (float, widened) and
   * {@code v} (double) into the single-precision tensor {@code w}, narrowing
   * each product to float. Nested dimensions are handled recursively, and
   * each operand advances by its own stride.
   *
   * @throws DimensionException if the operands' sizes or nesting disagree at
   *         any visited level
   */
  public static void multiply(TensorDims ud, float[] us, int ui,
                              TensorDims vd, double[] vs, int vi,
                              TensorDims wd, float[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size) {
      throw new DimensionException();
    }
    final boolean innermost = wd.next == null;
    if ((ud.next == null) != innermost || (vd.next == null) != innermost) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (innermost) {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        ws[wi] = (float) ((double) us[ui] * vs[vi]);
      }
    } else {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        multiply(ud.next, us, ui, vd.next, vs, vi, wd.next, ws, wi);
      }
    }
  }

  /**
   * Writes the element-wise product of the single-precision tensors {@code u}
   * and {@code v}, computed in double precision, into the double-precision
   * tensor {@code w}. Nested dimensions are handled recursively, and each
   * operand advances by its own stride.
   *
   * @throws DimensionException if the operands' sizes or nesting disagree at
   *         any visited level
   */
  public static void multiply(TensorDims ud, float[] us, int ui,
                              TensorDims vd, float[] vs, int vi,
                              TensorDims wd, double[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size) {
      throw new DimensionException();
    }
    final boolean innermost = wd.next == null;
    if ((ud.next == null) != innermost || (vd.next == null) != innermost) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (innermost) {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        ws[wi] = (double) us[ui] * (double) vs[vi];
      }
    } else {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        multiply(ud.next, us, ui, vd.next, vs, vi, wd.next, ws, wi);
      }
    }
  }

  /**
   * Writes the element-wise product of the single-precision tensors {@code u}
   * and {@code v} into the single-precision tensor {@code w}. Each product is
   * computed in double precision and then narrowed to float. Nested
   * dimensions are handled recursively, and each operand advances by its own
   * stride.
   *
   * @throws DimensionException if the operands' sizes or nesting disagree at
   *         any visited level
   */
  public static void multiply(TensorDims ud, float[] us, int ui,
                              TensorDims vd, float[] vs, int vi,
                              TensorDims wd, float[] ws, int wi) {
    if (ud.size != wd.size || vd.size != wd.size) {
      throw new DimensionException();
    }
    final boolean innermost = wd.next == null;
    if ((ud.next == null) != innermost || (vd.next == null) != innermost) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    if (innermost) {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        ws[wi] = (float) ((double) us[ui] * (double) vs[vi]);
      }
    } else {
      for (; wi < end; ui += ud.stride, vi += vd.stride, wi += wd.stride) {
        multiply(ud.next, us, ui, vd.next, vs, vi, wd.next, ws, wi);
      }
    }
  }

  /**
   * Returns the matrix product of this (rank-2) tensor and {@code that};
   * see {@link #multiplyMatrix(Tensor, Tensor)}.
   */
  public final Tensor timesMatrix(Tensor that) {
    final Tensor product = Tensor.multiplyMatrix(this, that);
    return product;
  }

  /**
   * Returns the matrix product of {@code u} and {@code v} (neither transposed),
   * allocated with {@code u}'s dimensions and the wider of the two precisions.
   */
  public static Tensor multiplyMatrix(Tensor u, Tensor v) {
    // NOTE(review): the result reuses u.dims, so this only fits when the product
    // has the same shape as u (e.g. v is square) -- confirm callers rely on this.
    final Precision wp = u.precision().max(v.precision());
    return multiplyMatrix(u, v, u.dims, wp);
  }

  /**
   * Returns the untransposed matrix product of {@code u} and {@code v} in a new
   * tensor with dimensions {@code wd} and precision {@code wp}.
   */
  public static Tensor multiplyMatrix(Tensor u, Tensor v, TensorDims wd, Precision wp) {
    final boolean noTranspose = false;
    return multiplyMatrix(u, noTranspose, v, noTranspose, wd, wp, noTranspose);
  }

  /**
   * Writes the matrix product of {@code u} and {@code v} into the existing
   * storage of {@code w}. The flags {@code ut}, {@code vt} and {@code wt}
   * select transposed access for each operand.
   */
  public static void multiplyMatrix(Tensor u, boolean ut, Tensor v, boolean vt, MutableTensor w, boolean wt) {
    multiplyMatrix(u.dims, u.array, u.offset, ut,
                   v.dims, v.array, v.offset, vt,
                   w.dims, w.array, w.offset, wt);
  }

  /**
   * Returns the matrix product of {@code u} and {@code v} in a freshly
   * allocated tensor with dimensions {@code wd} and precision {@code wp}.
   *
   * @throws AssertionError if {@code wp} is neither double nor single precision
   */
  public static Tensor multiplyMatrix(Tensor u, boolean ut, Tensor v, boolean vt, TensorDims wd, Precision wp, boolean wt) {
    if (!wp.isDouble() && !wp.isSingle()) {
      throw new AssertionError();
    }
    // The backing array spans the outermost dimension of wd; the result starts at offset 0.
    final int length = wd.size * wd.stride;
    final Object ws = wp.isDouble() ? new double[length] : new float[length];
    multiplyMatrix(u.dims, u.array, u.offset, ut, v.dims, v.array, v.offset, vt, wd, ws, 0, wt);
    return new Tensor(wd, ws, 0);
  }

  /**
   * Dispatches the matrix product to the overload that matches the runtime
   * array types of {@code us}, {@code vs} and {@code ws}.
   *
   * @throws AssertionError if any array is neither a {@code double[]} nor a {@code float[]}
   */
  public static void multiplyMatrix(TensorDims ud, Object us, int ui, boolean ut,
                                    TensorDims vd, Object vs, int vi, boolean vt,
                                    TensorDims wd, Object ws, int wi, boolean wt) {
    if (!(us instanceof double[] || us instanceof float[])
        || !(vs instanceof double[] || vs instanceof float[])
        || !(ws instanceof double[] || ws instanceof float[])) {
      throw new AssertionError();
    }
    // Past the guard, "not double[]" means "float[]" for every operand.
    final boolean u64 = us instanceof double[];
    final boolean v64 = vs instanceof double[];
    final boolean w64 = ws instanceof double[];
    if (u64 && v64 && w64) {
      multiplyMatrix(ud, (double[]) us, ui, ut, vd, (double[]) vs, vi, vt, wd, (double[]) ws, wi, wt);
    } else if (u64 && v64) {
      multiplyMatrix(ud, (double[]) us, ui, ut, vd, (double[]) vs, vi, vt, wd, (float[]) ws, wi, wt);
    } else if (u64 && w64) {
      multiplyMatrix(ud, (double[]) us, ui, ut, vd, (float[]) vs, vi, vt, wd, (double[]) ws, wi, wt);
    } else if (u64) {
      multiplyMatrix(ud, (double[]) us, ui, ut, vd, (float[]) vs, vi, vt, wd, (float[]) ws, wi, wt);
    } else if (v64 && w64) {
      multiplyMatrix(ud, (float[]) us, ui, ut, vd, (double[]) vs, vi, vt, wd, (double[]) ws, wi, wt);
    } else if (v64) {
      multiplyMatrix(ud, (float[]) us, ui, ut, vd, (double[]) vs, vi, vt, wd, (float[]) ws, wi, wt);
    } else if (w64) {
      multiplyMatrix(ud, (float[]) us, ui, ut, vd, (float[]) vs, vi, vt, wd, (double[]) ws, wi, wt);
    } else {
      multiplyMatrix(ud, (float[]) us, ui, ut, vd, (float[]) vs, vi, vt, wd, (float[]) ws, wi, wt);
    }
  }

  /**
   * Computes the matrix product {@code w = op(u) * op(v)}, where {@code op(x)}
   * is {@code x} transposed when the matching flag ({@code ut}, {@code vt}) is
   * set; {@code wt} selects transposed layout for the output. In the
   * untransposed layout this code walks the outer dimension ({@code dims}) as
   * columns and the inner dimension ({@code dims.next}) as rows.
   *
   * @throws DimensionException if any operand is not rank 2, if op(u) is m x n
   *         but op(v) is not n x p, or if w is not m x p
   */
  public static void multiplyMatrix(TensorDims ud, double[] us, int ui, boolean ut,
                                    TensorDims vd, double[] vs, int vi, boolean vt,
                                    TensorDims wd, double[] ws, int wi, boolean wt) {
    if (ud.rank() != 2 || vd.rank() != 2 || wd.rank() != 2) {
      throw new DimensionException();
    }
    final int m = ut ? ud.size : ud.next.size;
    final int n = ut ? ud.next.size : ud.size;
    final int p = vt ? vd.next.size : vd.size;
    if ((vt ? vd.size : vd.next.size) != n
        || (wt ? wd.size : wd.next.size) != m // row count of w (matches wr below)
        || (wt ? wd.next.size : wd.size) != p) { // column count of w (matches wc below)
      throw new DimensionException();
    }
    // Column (c) and row (r) strides of each operand after applying transposition.
    final int uc = ut ? ud.next.stride : ud.stride;
    final int ur = ut ? ud.stride : ud.next.stride;
    final int vc = vt ? vd.next.stride : vd.stride;
    final int vr = vt ? vd.stride : vd.next.stride;
    final int wc = wt ? wd.next.stride : wd.stride;
    final int wr = wt ? wd.stride : wd.next.stride;
    final int ui0 = ui;
    final int vi0 = vi;
    final int wi0 = wi;
    for (int i = 0; i < m; i += 1) {
      wi = wi0 + wr * i;
      for (int j = 0; j < p; j += 1) {
        // Dot product of row i of op(u) with column j of op(v).
        ui = ui0 + ur * i;
        vi = vi0 + vc * j;
        double dp = 0.0;
        for (int d = 0; d < n; d += 1) {
          dp += us[ui] * vs[vi];
          ui += uc;
          vi += vr;
        }
        ws[wi] = dp;
        wi += wc;
      }
    }
  }

  /**
   * Computes the matrix product {@code w = op(u) * op(v)} into a
   * single-precision output; dot products accumulate in double and are
   * narrowed to float on store. {@code op(x)} is {@code x} transposed when
   * the matching flag ({@code ut}, {@code vt}) is set; {@code wt} selects
   * transposed layout for the output. In the untransposed layout this code
   * walks the outer dimension ({@code dims}) as columns and the inner
   * dimension ({@code dims.next}) as rows.
   *
   * @throws DimensionException if any operand is not rank 2, if op(u) is m x n
   *         but op(v) is not n x p, or if w is not m x p
   */
  public static void multiplyMatrix(TensorDims ud, double[] us, int ui, boolean ut,
                                    TensorDims vd, double[] vs, int vi, boolean vt,
                                    TensorDims wd, float[] ws, int wi, boolean wt) {
    if (ud.rank() != 2 || vd.rank() != 2 || wd.rank() != 2) {
      throw new DimensionException();
    }
    final int m = ut ? ud.size : ud.next.size;
    final int n = ut ? ud.next.size : ud.size;
    final int p = vt ? vd.next.size : vd.size;
    if ((vt ? vd.size : vd.next.size) != n
        || (wt ? wd.size : wd.next.size) != m // row count of w (matches wr below)
        || (wt ? wd.next.size : wd.size) != p) { // column count of w (matches wc below)
      throw new DimensionException();
    }
    // Column (c) and row (r) strides of each operand after applying transposition.
    final int uc = ut ? ud.next.stride : ud.stride;
    final int ur = ut ? ud.stride : ud.next.stride;
    final int vc = vt ? vd.next.stride : vd.stride;
    final int vr = vt ? vd.stride : vd.next.stride;
    final int wc = wt ? wd.next.stride : wd.stride;
    final int wr = wt ? wd.stride : wd.next.stride;
    final int ui0 = ui;
    final int vi0 = vi;
    final int wi0 = wi;
    for (int i = 0; i < m; i += 1) {
      wi = wi0 + wr * i;
      for (int j = 0; j < p; j += 1) {
        // Dot product of row i of op(u) with column j of op(v).
        ui = ui0 + ur * i;
        vi = vi0 + vc * j;
        double dp = 0.0;
        for (int d = 0; d < n; d += 1) {
          dp += us[ui] * vs[vi];
          ui += uc;
          vi += vr;
        }
        ws[wi] = (float) dp;
        wi += wc;
      }
    }
  }

  /**
   * Computes the matrix product {@code w = op(u) * op(v)} for a double
   * {@code u} and float {@code v} (widened to double) into a double output.
   * {@code op(x)} is {@code x} transposed when the matching flag
   * ({@code ut}, {@code vt}) is set; {@code wt} selects transposed layout for
   * the output. In the untransposed layout this code walks the outer dimension
   * ({@code dims}) as columns and the inner dimension ({@code dims.next}) as rows.
   *
   * @throws DimensionException if any operand is not rank 2, if op(u) is m x n
   *         but op(v) is not n x p, or if w is not m x p
   */
  public static void multiplyMatrix(TensorDims ud, double[] us, int ui, boolean ut,
                                    TensorDims vd, float[] vs, int vi, boolean vt,
                                    TensorDims wd, double[] ws, int wi, boolean wt) {
    if (ud.rank() != 2 || vd.rank() != 2 || wd.rank() != 2) {
      throw new DimensionException();
    }
    final int m = ut ? ud.size : ud.next.size;
    final int n = ut ? ud.next.size : ud.size;
    final int p = vt ? vd.next.size : vd.size;
    if ((vt ? vd.size : vd.next.size) != n
        || (wt ? wd.size : wd.next.size) != m // row count of w (matches wr below)
        || (wt ? wd.next.size : wd.size) != p) { // column count of w (matches wc below)
      throw new DimensionException();
    }
    // Column (c) and row (r) strides of each operand after applying transposition.
    final int uc = ut ? ud.next.stride : ud.stride;
    final int ur = ut ? ud.stride : ud.next.stride;
    final int vc = vt ? vd.next.stride : vd.stride;
    final int vr = vt ? vd.stride : vd.next.stride;
    final int wc = wt ? wd.next.stride : wd.stride;
    final int wr = wt ? wd.stride : wd.next.stride;
    final int ui0 = ui;
    final int vi0 = vi;
    final int wi0 = wi;
    for (int i = 0; i < m; i += 1) {
      wi = wi0 + wr * i;
      for (int j = 0; j < p; j += 1) {
        // Dot product of row i of op(u) with column j of op(v).
        ui = ui0 + ur * i;
        vi = vi0 + vc * j;
        double dp = 0.0;
        for (int d = 0; d < n; d += 1) {
          dp += us[ui] * (double) vs[vi];
          ui += uc;
          vi += vr;
        }
        ws[wi] = dp;
        wi += wc;
      }
    }
  }

  /**
   * Computes the matrix product {@code w = op(u) * op(v)} for a double
   * {@code u} and float {@code v} into a single-precision output. Dot
   * products accumulate in double and are narrowed to float on store.
   * {@code op(x)} is {@code x} transposed when the matching flag
   * ({@code ut}, {@code vt}) is set; {@code wt} selects transposed layout for
   * the output. In the untransposed layout this code walks the outer dimension
   * ({@code dims}) as columns and the inner dimension ({@code dims.next}) as rows.
   *
   * @throws DimensionException if any operand is not rank 2, if op(u) is m x n
   *         but op(v) is not n x p, or if w is not m x p
   */
  public static void multiplyMatrix(TensorDims ud, double[] us, int ui, boolean ut,
                                    TensorDims vd, float[] vs, int vi, boolean vt,
                                    TensorDims wd, float[] ws, int wi, boolean wt) {
    if (ud.rank() != 2 || vd.rank() != 2 || wd.rank() != 2) {
      throw new DimensionException();
    }
    final int m = ut ? ud.size : ud.next.size;
    final int n = ut ? ud.next.size : ud.size;
    final int p = vt ? vd.next.size : vd.size;
    if ((vt ? vd.size : vd.next.size) != n
        || (wt ? wd.size : wd.next.size) != m // row count of w (matches wr below)
        || (wt ? wd.next.size : wd.size) != p) { // column count of w (matches wc below)
      throw new DimensionException();
    }
    // Column (c) and row (r) strides of each operand after applying transposition.
    final int uc = ut ? ud.next.stride : ud.stride;
    final int ur = ut ? ud.stride : ud.next.stride;
    final int vc = vt ? vd.next.stride : vd.stride;
    final int vr = vt ? vd.stride : vd.next.stride;
    final int wc = wt ? wd.next.stride : wd.stride;
    final int wr = wt ? wd.stride : wd.next.stride;
    final int ui0 = ui;
    final int vi0 = vi;
    final int wi0 = wi;
    for (int i = 0; i < m; i += 1) {
      wi = wi0 + wr * i;
      for (int j = 0; j < p; j += 1) {
        // Dot product of row i of op(u) with column j of op(v).
        ui = ui0 + ur * i;
        vi = vi0 + vc * j;
        double dp = 0.0;
        for (int d = 0; d < n; d += 1) {
          dp += us[ui] * (double) vs[vi];
          ui += uc;
          vi += vr;
        }
        ws[wi] = (float) dp;
        wi += wc;
      }
    }
  }

  /**
   * Computes the matrix product {@code w = op(u) * op(v)} for a float
   * {@code u} (widened to double) and double {@code v} into a double output.
   * {@code op(x)} is {@code x} transposed when the matching flag
   * ({@code ut}, {@code vt}) is set; {@code wt} selects transposed layout for
   * the output. In the untransposed layout this code walks the outer dimension
   * ({@code dims}) as columns and the inner dimension ({@code dims.next}) as rows.
   *
   * @throws DimensionException if any operand is not rank 2, if op(u) is m x n
   *         but op(v) is not n x p, or if w is not m x p
   */
  public static void multiplyMatrix(TensorDims ud, float[] us, int ui, boolean ut,
                                    TensorDims vd, double[] vs, int vi, boolean vt,
                                    TensorDims wd, double[] ws, int wi, boolean wt) {
    if (ud.rank() != 2 || vd.rank() != 2 || wd.rank() != 2) {
      throw new DimensionException();
    }
    final int m = ut ? ud.size : ud.next.size;
    final int n = ut ? ud.next.size : ud.size;
    final int p = vt ? vd.next.size : vd.size;
    if ((vt ? vd.size : vd.next.size) != n
        || (wt ? wd.size : wd.next.size) != m // row count of w (matches wr below)
        || (wt ? wd.next.size : wd.size) != p) { // column count of w (matches wc below)
      throw new DimensionException();
    }
    // Column (c) and row (r) strides of each operand after applying transposition.
    final int uc = ut ? ud.next.stride : ud.stride;
    final int ur = ut ? ud.stride : ud.next.stride;
    final int vc = vt ? vd.next.stride : vd.stride;
    final int vr = vt ? vd.stride : vd.next.stride;
    final int wc = wt ? wd.next.stride : wd.stride;
    final int wr = wt ? wd.stride : wd.next.stride;
    final int ui0 = ui;
    final int vi0 = vi;
    final int wi0 = wi;
    for (int i = 0; i < m; i += 1) {
      wi = wi0 + wr * i;
      for (int j = 0; j < p; j += 1) {
        // Dot product of row i of op(u) with column j of op(v).
        ui = ui0 + ur * i;
        vi = vi0 + vc * j;
        double dp = 0.0;
        for (int d = 0; d < n; d += 1) {
          dp += (double) us[ui] * vs[vi];
          ui += uc;
          vi += vr;
        }
        ws[wi] = dp;
        wi += wc;
      }
    }
  }

  /**
   * Computes the matrix product {@code w = op(u) * op(v)} for a float
   * {@code u} and double {@code v} into a single-precision output. Dot
   * products accumulate in double and are narrowed to float on store.
   * {@code op(x)} is {@code x} transposed when the matching flag
   * ({@code ut}, {@code vt}) is set; {@code wt} selects transposed layout for
   * the output. In the untransposed layout this code walks the outer dimension
   * ({@code dims}) as columns and the inner dimension ({@code dims.next}) as rows.
   *
   * @throws DimensionException if any operand is not rank 2, if op(u) is m x n
   *         but op(v) is not n x p, or if w is not m x p
   */
  public static void multiplyMatrix(TensorDims ud, float[] us, int ui, boolean ut,
                                    TensorDims vd, double[] vs, int vi, boolean vt,
                                    TensorDims wd, float[] ws, int wi, boolean wt) {
    if (ud.rank() != 2 || vd.rank() != 2 || wd.rank() != 2) {
      throw new DimensionException();
    }
    final int m = ut ? ud.size : ud.next.size;
    final int n = ut ? ud.next.size : ud.size;
    final int p = vt ? vd.next.size : vd.size;
    if ((vt ? vd.size : vd.next.size) != n
        || (wt ? wd.size : wd.next.size) != m // row count of w (matches wr below)
        || (wt ? wd.next.size : wd.size) != p) { // column count of w (matches wc below)
      throw new DimensionException();
    }
    // Column (c) and row (r) strides of each operand after applying transposition.
    final int uc = ut ? ud.next.stride : ud.stride;
    final int ur = ut ? ud.stride : ud.next.stride;
    final int vc = vt ? vd.next.stride : vd.stride;
    final int vr = vt ? vd.stride : vd.next.stride;
    final int wc = wt ? wd.next.stride : wd.stride;
    final int wr = wt ? wd.stride : wd.next.stride;
    final int ui0 = ui;
    final int vi0 = vi;
    final int wi0 = wi;
    for (int i = 0; i < m; i += 1) {
      wi = wi0 + wr * i;
      for (int j = 0; j < p; j += 1) {
        // Dot product of row i of op(u) with column j of op(v).
        ui = ui0 + ur * i;
        vi = vi0 + vc * j;
        double dp = 0.0;
        for (int d = 0; d < n; d += 1) {
          dp += (double) us[ui] * vs[vi];
          ui += uc;
          vi += vr;
        }
        ws[wi] = (float) dp;
        wi += wc;
      }
    }
  }

  /**
   * Computes the matrix product {@code w = op(u) * op(v)} for float operands
   * (widened to double) into a double output. {@code op(x)} is {@code x}
   * transposed when the matching flag ({@code ut}, {@code vt}) is set;
   * {@code wt} selects transposed layout for the output. In the untransposed
   * layout this code walks the outer dimension ({@code dims}) as columns and
   * the inner dimension ({@code dims.next}) as rows.
   *
   * @throws DimensionException if any operand is not rank 2, if op(u) is m x n
   *         but op(v) is not n x p, or if w is not m x p
   */
  public static void multiplyMatrix(TensorDims ud, float[] us, int ui, boolean ut,
                                    TensorDims vd, float[] vs, int vi, boolean vt,
                                    TensorDims wd, double[] ws, int wi, boolean wt) {
    if (ud.rank() != 2 || vd.rank() != 2 || wd.rank() != 2) {
      throw new DimensionException();
    }
    final int m = ut ? ud.size : ud.next.size;
    final int n = ut ? ud.next.size : ud.size;
    final int p = vt ? vd.next.size : vd.size;
    if ((vt ? vd.size : vd.next.size) != n
        || (wt ? wd.size : wd.next.size) != m // row count of w (matches wr below)
        || (wt ? wd.next.size : wd.size) != p) { // column count of w (matches wc below)
      throw new DimensionException();
    }
    // Column (c) and row (r) strides of each operand after applying transposition.
    final int uc = ut ? ud.next.stride : ud.stride;
    final int ur = ut ? ud.stride : ud.next.stride;
    final int vc = vt ? vd.next.stride : vd.stride;
    final int vr = vt ? vd.stride : vd.next.stride;
    final int wc = wt ? wd.next.stride : wd.stride;
    final int wr = wt ? wd.stride : wd.next.stride;
    final int ui0 = ui;
    final int vi0 = vi;
    final int wi0 = wi;
    for (int i = 0; i < m; i += 1) {
      wi = wi0 + wr * i;
      for (int j = 0; j < p; j += 1) {
        // Dot product of row i of op(u) with column j of op(v).
        ui = ui0 + ur * i;
        vi = vi0 + vc * j;
        double dp = 0.0;
        for (int d = 0; d < n; d += 1) {
          dp += (double) us[ui] * (double) vs[vi];
          ui += uc;
          vi += vr;
        }
        ws[wi] = dp;
        wi += wc;
      }
    }
  }

  /**
   * Computes the matrix product {@code w = op(u) * op(v)} for float operands
   * into a single-precision output. Dot products accumulate in double and
   * are narrowed to float on store. {@code op(x)} is {@code x} transposed
   * when the matching flag ({@code ut}, {@code vt}) is set; {@code wt} selects
   * transposed layout for the output. In the untransposed layout this code
   * walks the outer dimension ({@code dims}) as columns and the inner
   * dimension ({@code dims.next}) as rows.
   *
   * @throws DimensionException if any operand is not rank 2, if op(u) is m x n
   *         but op(v) is not n x p, or if w is not m x p
   */
  public static void multiplyMatrix(TensorDims ud, float[] us, int ui, boolean ut,
                                    TensorDims vd, float[] vs, int vi, boolean vt,
                                    TensorDims wd, float[] ws, int wi, boolean wt) {
    if (ud.rank() != 2 || vd.rank() != 2 || wd.rank() != 2) {
      throw new DimensionException();
    }
    final int m = ut ? ud.size : ud.next.size;
    final int n = ut ? ud.next.size : ud.size;
    final int p = vt ? vd.next.size : vd.size;
    if ((vt ? vd.size : vd.next.size) != n
        || (wt ? wd.size : wd.next.size) != m // row count of w (matches wr below)
        || (wt ? wd.next.size : wd.size) != p) { // column count of w (matches wc below)
      throw new DimensionException();
    }
    // Column (c) and row (r) strides of each operand after applying transposition.
    final int uc = ut ? ud.next.stride : ud.stride;
    final int ur = ut ? ud.stride : ud.next.stride;
    final int vc = vt ? vd.next.stride : vd.stride;
    final int vr = vt ? vd.stride : vd.next.stride;
    final int wc = wt ? wd.next.stride : wd.stride;
    final int wr = wt ? wd.stride : wd.next.stride;
    final int ui0 = ui;
    final int vi0 = vi;
    final int wi0 = wi;
    for (int i = 0; i < m; i += 1) {
      wi = wi0 + wr * i;
      for (int j = 0; j < p; j += 1) {
        // Dot product of row i of op(u) with column j of op(v).
        ui = ui0 + ur * i;
        vi = vi0 + vc * j;
        double dp = 0.0;
        for (int d = 0; d < n; d += 1) {
          dp += (double) us[ui] * (double) vs[vi];
          ui += uc;
          vi += vr;
        }
        ws[wi] = (float) dp;
        wi += wc;
      }
    }
  }

  /**
   * Returns {@code a * u + b * v}, using {@code u}'s dimensions and the wider
   * of the two operand precisions.
   */
  public static Tensor combine(double a, Tensor u, double b, Tensor v) {
    final Precision wp = u.precision().max(v.precision());
    return combine(a, u, b, v, u.dims, wp);
  }

  /**
   * Writes {@code a * u + b * v} into the existing storage of {@code w},
   * starting at {@code w}'s offset.
   */
  public static void combine(double a, Tensor u, double b, Tensor v, MutableTensor w) {
    combine(a, u.dims, u.array, u.offset,
            b, v.dims, v.array, v.offset,
            w.dims, w.array, w.offset);
  }

  /**
   * Returns {@code a * u + b * v} in a freshly allocated tensor with
   * dimensions {@code wd} and precision {@code wp}.
   *
   * @throws AssertionError if {@code wp} is neither double nor single precision
   */
  public static Tensor combine(double a, Tensor u, double b, Tensor v, TensorDims wd, Precision wp) {
    if (!wp.isDouble() && !wp.isSingle()) {
      throw new AssertionError();
    }
    // The backing array spans the outermost dimension of wd; the result starts at offset 0.
    final int length = wd.size * wd.stride;
    final Object ws = wp.isDouble() ? new double[length] : new float[length];
    combine(a, u.dims, u.array, u.offset, b, v.dims, v.array, v.offset, wd, ws, 0);
    return new Tensor(wd, ws, 0);
  }

  /**
   * Dispatches {@code a * u + b * v} to the overload that matches the runtime
   * array types of {@code us}, {@code vs} and {@code ws}.
   *
   * @throws AssertionError if any array is neither a {@code double[]} nor a {@code float[]}
   */
  public static void combine(double a, TensorDims ud, Object us, int ui,
                             double b, TensorDims vd, Object vs, int vi,
                             TensorDims wd, Object ws, int wi) {
    if (!(us instanceof double[] || us instanceof float[])
        || !(vs instanceof double[] || vs instanceof float[])
        || !(ws instanceof double[] || ws instanceof float[])) {
      throw new AssertionError();
    }
    // Past the guard, "not double[]" means "float[]" for every operand.
    final boolean u64 = us instanceof double[];
    final boolean v64 = vs instanceof double[];
    final boolean w64 = ws instanceof double[];
    if (u64 && v64 && w64) {
      combine(a, ud, (double[]) us, ui, b, vd, (double[]) vs, vi, wd, (double[]) ws, wi);
    } else if (u64 && v64) {
      combine(a, ud, (double[]) us, ui, b, vd, (double[]) vs, vi, wd, (float[]) ws, wi);
    } else if (u64 && w64) {
      combine(a, ud, (double[]) us, ui, b, vd, (float[]) vs, vi, wd, (double[]) ws, wi);
    } else if (u64) {
      combine(a, ud, (double[]) us, ui, b, vd, (float[]) vs, vi, wd, (float[]) ws, wi);
    } else if (v64 && w64) {
      combine(a, ud, (float[]) us, ui, b, vd, (double[]) vs, vi, wd, (double[]) ws, wi);
    } else if (v64) {
      combine(a, ud, (float[]) us, ui, b, vd, (double[]) vs, vi, wd, (float[]) ws, wi);
    } else if (w64) {
      combine(a, ud, (float[]) us, ui, b, vd, (float[]) vs, vi, wd, (double[]) ws, wi);
    } else {
      combine(a, ud, (float[]) us, ui, b, vd, (float[]) vs, vi, wd, (float[]) ws, wi);
    }
  }

  /**
   * Writes the linear combination {@code a * u + b * v} into {@code w}.
   * Nested dimensions are handled recursively, and each operand advances by
   * its own stride.
   *
   * @throws DimensionException if the operands' sizes or nesting disagree at
   *         any visited level
   */
  public static void combine(double a, TensorDims ud, double[] us, int ui,
                             double b, TensorDims vd, double[] vs, int vi,
                             TensorDims wd, double[] ws, int wi) {
    if (ud.size != vd.size || ud.size != wd.size || vd.size != wd.size) {
      throw new DimensionException();
    }
    // wn is measured in w's own stride, so w must always advance by wd.stride.
    final int wn = wi + wd.size * wd.stride;
    if (wd.next != null) {
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        combine(a, ud.next, us, ui, b, vd.next, vs, vi, wd.next, ws, wi);
        ui += ud.stride;
        vi += vd.stride;
        wi += wd.stride;
      }
    } else {
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        ws[wi] = a * us[ui] + b * vs[vi];
        ui += ud.stride;
        vi += vd.stride;
        wi += wd.stride;
      }
    }
  }

  /**
   * Writes the linear combination {@code a * u + b * v} into the
   * single-precision tensor {@code w}, narrowing each result to float.
   * Nested dimensions are handled recursively, and each operand advances by
   * its own stride.
   *
   * @throws DimensionException if the operands' sizes or nesting disagree at
   *         any visited level
   */
  public static void combine(double a, TensorDims ud, double[] us, int ui,
                             double b, TensorDims vd, double[] vs, int vi,
                             TensorDims wd, float[] ws, int wi) {
    if (ud.size != vd.size || ud.size != wd.size || vd.size != wd.size) {
      throw new DimensionException();
    }
    // wn is measured in w's own stride, so w must always advance by wd.stride.
    final int wn = wi + wd.size * wd.stride;
    if (wd.next != null) {
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        combine(a, ud.next, us, ui, b, vd.next, vs, vi, wd.next, ws, wi);
        ui += ud.stride;
        vi += vd.stride;
        wi += wd.stride;
      }
    } else {
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        ws[wi] = (float) (a * us[ui] + b * vs[vi]);
        ui += ud.stride;
        vi += vd.stride;
        wi += wd.stride;
      }
    }
  }

  /**
   * Writes the linear combination {@code a * u + b * v} into {@code w}, where
   * {@code v} is single precision (widened to double). Nested dimensions are
   * handled recursively, and each operand advances by its own stride.
   *
   * @throws DimensionException if the operands' sizes or nesting disagree at
   *         any visited level
   */
  public static void combine(double a, TensorDims ud, double[] us, int ui,
                             double b, TensorDims vd, float[] vs, int vi,
                             TensorDims wd, double[] ws, int wi) {
    if (ud.size != vd.size || ud.size != wd.size || vd.size != wd.size) {
      throw new DimensionException();
    }
    // wn is measured in w's own stride, so w must always advance by wd.stride.
    final int wn = wi + wd.size * wd.stride;
    if (wd.next != null) {
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        combine(a, ud.next, us, ui, b, vd.next, vs, vi, wd.next, ws, wi);
        ui += ud.stride;
        vi += vd.stride;
        wi += wd.stride;
      }
    } else {
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        ws[wi] = a * us[ui] + b * (double) vs[vi];
        ui += ud.stride;
        vi += vd.stride;
        wi += wd.stride;
      }
    }
  }

  /**
   * Writes the linear combination {@code a * u + b * v} into the
   * single-precision tensor {@code w}, where {@code v} is single precision.
   * Each result is computed in double and narrowed to float. Nested
   * dimensions are handled recursively, and each operand advances by its own
   * stride.
   *
   * @throws DimensionException if the operands' sizes or nesting disagree at
   *         any visited level
   */
  public static void combine(double a, TensorDims ud, double[] us, int ui,
                             double b, TensorDims vd, float[] vs, int vi,
                             TensorDims wd, float[] ws, int wi) {
    if (ud.size != vd.size || ud.size != wd.size || vd.size != wd.size) {
      throw new DimensionException();
    }
    // wn is measured in w's own stride, so w must always advance by wd.stride.
    final int wn = wi + wd.size * wd.stride;
    if (wd.next != null) {
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        combine(a, ud.next, us, ui, b, vd.next, vs, vi, wd.next, ws, wi);
        ui += ud.stride;
        vi += vd.stride;
        wi += wd.stride;
      }
    } else {
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        ws[wi] = (float) (a * us[ui] + b * (double) vs[vi]);
        ui += ud.stride;
        vi += vd.stride;
        wi += wd.stride;
      }
    }
  }

  /**
   * Writes the linear combination {@code a * u + b * v} into {@code w}, where
   * {@code u} is single precision (widened to double). Nested dimensions are
   * handled recursively, and each operand advances by its own stride.
   *
   * @throws DimensionException if the operands' sizes or nesting disagree at
   *         any visited level
   */
  public static void combine(double a, TensorDims ud, float[] us, int ui,
                             double b, TensorDims vd, double[] vs, int vi,
                             TensorDims wd, double[] ws, int wi) {
    if (ud.size != vd.size || ud.size != wd.size || vd.size != wd.size) {
      throw new DimensionException();
    }
    // wn is measured in w's own stride, so w must always advance by wd.stride.
    final int wn = wi + wd.size * wd.stride;
    if (wd.next != null) {
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        combine(a, ud.next, us, ui, b, vd.next, vs, vi, wd.next, ws, wi);
        ui += ud.stride;
        vi += vd.stride;
        wi += wd.stride;
      }
    } else {
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        ws[wi] = a * (double) us[ui] + b * vs[vi];
        ui += ud.stride;
        vi += vd.stride;
        wi += wd.stride;
      }
    }
  }

  /**
   * Writes the element-wise linear combination {@code a * u + b * v} into {@code w},
   * where {@code u} is backed by a {@code float[]}, {@code v} by a {@code double[]},
   * and {@code w} by a {@code float[]}. Each element is computed in double precision
   * and then narrowed to {@code float}.
   *
   * @param a  coefficient applied to {@code u}
   * @param ud dimensions of {@code u}
   * @param us backing array of {@code u}
   * @param ui index in {@code us} where {@code u} starts
   * @param b  coefficient applied to {@code v}
   * @param vd dimensions of {@code v}
   * @param vs backing array of {@code v}
   * @param vi index in {@code vs} where {@code v} starts
   * @param wd dimensions of the output {@code w}
   * @param ws backing array of {@code w}, written in place
   * @param wi index in {@code ws} where {@code w} starts
   * @throws DimensionException if {@code u}, {@code v} and {@code w} differ in size
   *         or nesting depth at any level
   */
  public static void combine(double a, TensorDims ud, float[] us, int ui,
                             double b, TensorDims vd, double[] vs, int vi,
                             TensorDims wd, float[] ws, int wi) {
    // vd.size == wd.size follows transitively from these two comparisons.
    if (ud.size != vd.size || ud.size != wd.size) {
      throw new DimensionException();
    }
    final int wn = wi + wd.size * wd.stride;
    if (wd.next != null) {
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        combine(a, ud.next, us, ui, b, vd.next, vs, vi, wd.next, ws, wi);
        ui += ud.stride;
        vi += vd.stride;
        // Advance w by its own stride: wn is measured in wd.stride units, so
        // stepping by ud.stride mis-indexes whenever the strides differ.
        wi += wd.stride;
      }
    } else {
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        ws[wi] = (float) (a * (double) us[ui] + b * vs[vi]);
        ui += ud.stride;
        vi += vd.stride;
        wi += wd.stride;
      }
    }
  }

  /**
   * Writes the element-wise linear combination {@code a * u + b * v} into {@code w},
   * where {@code u} and {@code v} are backed by {@code float[]} arrays and {@code w}
   * by a {@code double[]}. Input elements are widened to {@code double} before the
   * arithmetic.
   *
   * @param a  coefficient applied to {@code u}
   * @param ud dimensions of {@code u}
   * @param us backing array of {@code u}
   * @param ui index in {@code us} where {@code u} starts
   * @param b  coefficient applied to {@code v}
   * @param vd dimensions of {@code v}
   * @param vs backing array of {@code v}
   * @param vi index in {@code vs} where {@code v} starts
   * @param wd dimensions of the output {@code w}
   * @param ws backing array of {@code w}, written in place
   * @param wi index in {@code ws} where {@code w} starts
   * @throws DimensionException if {@code u}, {@code v} and {@code w} differ in size
   *         or nesting depth at any level
   */
  public static void combine(double a, TensorDims ud, float[] us, int ui,
                             double b, TensorDims vd, float[] vs, int vi,
                             TensorDims wd, double[] ws, int wi) {
    // vd.size == wd.size follows transitively from these two comparisons.
    if (ud.size != vd.size || ud.size != wd.size) {
      throw new DimensionException();
    }
    final int wn = wi + wd.size * wd.stride;
    if (wd.next != null) {
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        combine(a, ud.next, us, ui, b, vd.next, vs, vi, wd.next, ws, wi);
        ui += ud.stride;
        vi += vd.stride;
        // Advance w by its own stride: wn is measured in wd.stride units, so
        // stepping by ud.stride mis-indexes whenever the strides differ.
        wi += wd.stride;
      }
    } else {
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        ws[wi] = a * (double) us[ui] + b * (double) vs[vi];
        ui += ud.stride;
        vi += vd.stride;
        wi += wd.stride;
      }
    }
  }

  /**
   * Writes the element-wise linear combination {@code a * u + b * v} into {@code w},
   * where {@code u}, {@code v} and {@code w} are all backed by {@code float[]} arrays.
   * Each element is computed in double precision and then narrowed to {@code float}.
   *
   * @param a  coefficient applied to {@code u}
   * @param ud dimensions of {@code u}
   * @param us backing array of {@code u}
   * @param ui index in {@code us} where {@code u} starts
   * @param b  coefficient applied to {@code v}
   * @param vd dimensions of {@code v}
   * @param vs backing array of {@code v}
   * @param vi index in {@code vs} where {@code v} starts
   * @param wd dimensions of the output {@code w}
   * @param ws backing array of {@code w}, written in place
   * @param wi index in {@code ws} where {@code w} starts
   * @throws DimensionException if {@code u}, {@code v} and {@code w} differ in size
   *         or nesting depth at any level
   */
  public static void combine(double a, TensorDims ud, float[] us, int ui,
                             double b, TensorDims vd, float[] vs, int vi,
                             TensorDims wd, float[] ws, int wi) {
    // vd.size == wd.size follows transitively from these two comparisons.
    if (ud.size != vd.size || ud.size != wd.size) {
      throw new DimensionException();
    }
    final int wn = wi + wd.size * wd.stride;
    if (wd.next != null) {
      if (ud.next == null || vd.next == null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        combine(a, ud.next, us, ui, b, vd.next, vs, vi, wd.next, ws, wi);
        ui += ud.stride;
        vi += vd.stride;
        // Advance w by its own stride: wn is measured in wd.stride units, so
        // stepping by ud.stride mis-indexes whenever the strides differ.
        wi += wd.stride;
      }
    } else {
      if (ud.next != null || vd.next != null) {
        throw new DimensionException();
      }
      while (wi < wn) {
        ws[wi] = (float) (a * (double) us[ui] + b * (double) vs[vi]);
        ui += ud.stride;
        vi += vd.stride;
        wi += wd.stride;
      }
    }
  }

  /**
   * Returns the {@code double[]} backing this tensor, or {@code null} when the
   * tensor is stored in some other element type.
   */
  public final double[] getDoubleArray() {
    final Object backing = this.array;
    return backing instanceof double[] ? (double[]) backing : null;
  }

  /**
   * Returns the {@code float[]} backing this tensor, or {@code null} when the
   * tensor is stored in some other element type.
   */
  public final float[] getFloatArray() {
    final Object backing = this.array;
    return backing instanceof float[] ? (float[]) backing : null;
  }

  /**
   * Returns the index into the backing array at which this tensor's elements start.
   */
  public final int getArrayOffset() {
    return offset;
  }

  /**
   * Two tensors are equal when they use the same backing element type and their
   * viewed elements agree in shape and value.
   */
  @Override
  public boolean equals(Object other) {
    if (this == other) {
      return true;
    }
    if (!(other instanceof Tensor)) {
      return false;
    }
    final Tensor that = (Tensor) other;
    final Object mine = this.array;
    final Object theirs = that.array;
    if (mine instanceof double[]) {
      // A double-backed tensor never equals a float-backed one.
      return theirs instanceof double[]
          && equals(this.dims, (double[]) mine, this.offset,
                    that.dims, (double[]) theirs, that.offset);
    }
    if (mine instanceof float[]) {
      return theirs instanceof float[]
          && equals(this.dims, (float[]) mine, this.offset,
                    that.dims, (float[]) theirs, that.offset);
    }
    return false;
  }

  /**
   * Compares two {@code double}-backed tensor views element by element.
   * Elements are compared with {@code !=}, so a {@code NaN} element never
   * matches and {@code 0.0} matches {@code -0.0}.
   *
   * @param ud dimensions of the first view
   * @param us backing array of the first view
   * @param ui index in {@code us} where the first view starts
   * @param vd dimensions of the second view
   * @param vs backing array of the second view
   * @param vi index in {@code vs} where the second view starts
   * @return {@code true} if both views have the same size and nesting depth at
   *         every level and all corresponding elements compare equal
   */
  static boolean equals(TensorDims ud, double[] us, int ui,
                        TensorDims vd, double[] vs, int vi) {
    if (ud.size != vd.size) {
      return false;
    }
    final int un = ui + ud.size * ud.stride;
    if (ud.next != null) {
      // Rank mismatch: u has an inner dimension that v lacks. (The previous
      // `vd == null` test could never be true and let a null reach the recursion.)
      if (vd.next == null) {
        return false;
      }
      while (ui < un) {
        if (!equals(ud.next, us, ui, vd.next, vs, vi)) {
          return false;
        }
        ui += ud.stride;
        vi += vd.stride;
      }
    } else {
      if (vd.next != null) {
        return false;
      }
      while (ui < un) {
        if (us[ui] != vs[vi]) {
          return false;
        }
        ui += ud.stride;
        vi += vd.stride;
      }
    }
    return true;
  }

  /**
   * Compares two {@code float}-backed tensor views element by element.
   * Elements are compared with {@code !=}, so a {@code NaN} element never
   * matches and {@code 0.0f} matches {@code -0.0f}.
   *
   * @param ud dimensions of the first view
   * @param us backing array of the first view
   * @param ui index in {@code us} where the first view starts
   * @param vd dimensions of the second view
   * @param vs backing array of the second view
   * @param vi index in {@code vs} where the second view starts
   * @return {@code true} if both views have the same size and nesting depth at
   *         every level and all corresponding elements compare equal
   */
  static boolean equals(TensorDims ud, float[] us, int ui,
                        TensorDims vd, float[] vs, int vi) {
    if (ud.size != vd.size) {
      return false;
    }
    final int un = ui + ud.size * ud.stride;
    if (ud.next != null) {
      // Rank mismatch: u has an inner dimension that v lacks. (The previous
      // `vd == null` test could never be true and let a null reach the recursion.)
      if (vd.next == null) {
        return false;
      }
      while (ui < un) {
        if (!equals(ud.next, us, ui, vd.next, vs, vi)) {
          return false;
        }
        ui += ud.stride;
        vi += vd.stride;
      }
    } else {
      if (vd.next != null) {
        return false;
      }
      while (ui < un) {
        if (us[ui] != vs[vi]) {
          return false;
        }
        ui += ud.stride;
        vi += vd.stride;
      }
    }
    return true;
  }

  /**
   * Hashes the viewed elements of this tensor with Murmur3, seeded by a
   * class-level seed that is computed on first use.
   */
  @Override
  public int hashCode() {
    int seed = hashSeed;
    if (seed == 0) {
      seed = Murmur3.seed(Tensor.class);
      hashSeed = seed;
    }
    final Object elements = this.array;
    final int code;
    if (elements instanceof double[]) {
      code = hash(seed, this.dims, (double[]) elements, this.offset);
    } else if (elements instanceof float[]) {
      code = hash(seed, this.dims, (float[]) elements, this.offset);
    } else {
      throw new AssertionError();
    }
    return Murmur3.mash(code);
  }

  /**
   * Folds every element of a {@code double}-backed tensor view into the running
   * Murmur3 state {@code code}, visiting each outer index and recursing into
   * {@code ud.next} for nested dimensions.
   *
   * @param code running hash state
   * @param ud   dimensions of the view
   * @param us   backing array
   * @param ui   index in {@code us} where the view starts
   * @return the updated hash state
   */
  static int hash(int code, TensorDims ud, double[] us, int ui) {
    final int limit = ui + ud.size * ud.stride;
    if (ud.next != null) {
      while (ui < limit) {
        // Thread the state through the recursion; discarding the result made
        // every multi-dimensional tensor hash to the bare seed.
        code = hash(code, ud.next, us, ui);
        ui += ud.stride;
      }
    } else {
      while (ui < limit) {
        code = Murmur3.mix(code, Murmur3.hash(us[ui]));
        ui += ud.stride;
      }
    }
    return code;
  }

  /**
   * Folds every element of a {@code float}-backed tensor view into the running
   * Murmur3 state {@code code}, visiting each outer index and recursing into
   * {@code ud.next} for nested dimensions.
   *
   * @param code running hash state
   * @param ud   dimensions of the view
   * @param us   backing array
   * @param ui   index in {@code us} where the view starts
   * @return the updated hash state
   */
  static int hash(int code, TensorDims ud, float[] us, int ui) {
    final int limit = ui + ud.size * ud.stride;
    if (ud.next != null) {
      while (ui < limit) {
        // Thread the state through the recursion; discarding the result made
        // every multi-dimensional tensor hash to the bare seed.
        code = hash(code, ud.next, us, ui);
        ui += ud.stride;
      }
    } else {
      while (ui < limit) {
        code = Murmur3.mix(code, Murmur3.hash(us[ui]));
        ui += ud.stride;
      }
    }
    return code;
  }

  /**
   * Writes {@code Tensor.of(dims, offset, e0, e1, ...)} to {@code output}, listing
   * every element of the backing array (not only the elements in this view).
   */
  @Override
  public void debug(Output output) {
    output = output.write("Tensor").write('.').write("of").write('(')
        .debug(this.dims).write(", ").debug(this.offset);
    final Object us = this.array;
    if (us instanceof double[]) {
      // Keep the continuation returned by the helper so the closing paren is
      // written to the current output rather than a stale one.
      output = debug(output, (double[]) us);
    } else if (us instanceof float[]) {
      output = debug(output, (float[]) us);
    } else {
      throw new AssertionError();
    }
    output = output.write(')');
  }

  /**
   * Appends {@code ", " + element} for every element of {@code us}.
   *
   * @return the output continuation after the last write
   */
  static Output debug(Output output, double[] us) {
    for (int i = 0, n = us.length; i < n; i += 1) {
      output = output.write(", ").debug(us[i]);
    }
    return output;
  }

  /**
   * Appends {@code ", " + element} for every element of {@code us}.
   *
   * @return the output continuation after the last write
   */
  static Output debug(Output output, float[] us) {
    for (int i = 0, n = us.length; i < n; i += 1) {
      output = output.write(", ").debug(us[i]);
    }
    return output;
  }

  /** Returns the same text that {@link #debug(Output)} writes. */
  @Override
  public String toString() {
    final String text = Format.debug(this);
    return text;
  }

  // Murmur3 seed used by hashCode(); 0 means it has not been computed yet.
  private static int hashSeed = 0;

  /**
   * Returns a new {@code float}-backed tensor with the given dimensions and
   * every element set to zero.
   */
  public static Tensor zero(TensorDims dims) {
    final int length = dims.size * dims.stride;
    final float[] elements = new float[length];
    return new Tensor(dims, elements);
  }

  /**
   * Wraps {@code array} (without copying) as a tensor whose elements start at
   * {@code offset}.
   */
  public static Tensor of(TensorDims dims, int offset, double... array) {
    final Tensor tensor = new Tensor(dims, array, offset);
    return tensor;
  }

  /**
   * Wraps {@code array} (without copying) as a tensor whose elements start at
   * {@code offset}.
   */
  public static Tensor of(TensorDims dims, int offset, float... array) {
    final Tensor tensor = new Tensor(dims, array, offset);
    return tensor;
  }

  /** Returns a tensor space with the given dimensions, nested over {@code next}. */
  public static TensorSpace space(TensorSpace next, TensorDims dims) {
    final TensorSpace result = new TensorObjectSpace(next, dims);
    return result;
  }

  /**
   * Returns a tensor space nested over {@code next}, whose dimensions are
   * {@code next.dimensions().by(n)}.
   */
  public static TensorSpace space(TensorSpace next, int n) {
    final TensorSpace result = new TensorObjectSpace(next, next.dimensions().by(n));
    return result;
  }

  /**
   * Returns a tensor space with no enclosing space.
   *
   * @throws DimensionException if {@code dims} has a nested dimension
   */
  public static TensorSpace space(TensorDims dims) {
    if (dims.next == null) {
      return new TensorObjectSpace(null, dims);
    }
    throw new DimensionException();
  }

  /** Returns a tensor space with no enclosing space and dimensions {@code TensorDims.of(n)}. */
  public static TensorSpace space(int n) {
    final TensorSpace result = new TensorObjectSpace(null, TensorDims.of(n));
    return result;
  }

  /** Returns a structural form for tensors of the given dimensions and precision. */
  public static TensorForm form(TensorDims dims, Precision prec) {
    final TensorForm result = new TensorObjectForm(dims, prec);
    return result;
  }

  /** Returns a structural form for tensors of the given dimensions in {@code f64} precision. */
  public static TensorForm form(TensorDims dims) {
    return form(dims, Precision.f64());
  }

  /**
   * Converts the viewed elements of {@code u} into nested records tagged with
   * {@code tag}, dispatching on the backing array type.
   */
  public static Item mold(String tag, Tensor u) {
    final Object elements = u.array;
    if (elements instanceof double[]) {
      return mold(tag, u.dims, (double[]) elements, u.offset);
    }
    if (elements instanceof float[]) {
      return mold(tag, u.dims, (float[]) elements, u.offset);
    }
    throw new AssertionError();
  }

  /**
   * Converts a {@code double}-backed tensor view into a record holding one
   * attribute named {@code tag}. The attribute's value lists the elements of
   * this level, each being a nested molded record when {@code ud.next} is
   * present and a raw element otherwise.
   */
  public static Item mold(String tag, TensorDims ud, double[] us, int ui) {
    final int end = ui + ud.size * ud.stride;
    final Record header = Record.create(ud.size);
    final TensorDims inner = ud.next;
    for (int cursor = ui; cursor < end; cursor += ud.stride) {
      if (inner == null) {
        header.item(us[cursor]);
      } else {
        header.item(mold(tag, inner, us, cursor));
      }
    }
    return Record.create(1).attr(tag, header);
  }

  /**
   * Converts a {@code float}-backed tensor view into a record holding one
   * attribute named {@code tag}. The attribute's value lists the elements of
   * this level, each being a nested molded record when {@code ud.next} is
   * present and a raw element otherwise.
   */
  public static Item mold(String tag, TensorDims ud, float[] us, int ui) {
    final int end = ui + ud.size * ud.stride;
    final Record header = Record.create(ud.size);
    final TensorDims inner = ud.next;
    for (int cursor = ui; cursor < end; cursor += ud.stride) {
      if (inner == null) {
        header.item(us[cursor]);
      } else {
        header.item(mold(tag, inner, us, cursor));
      }
    }
    return Record.create(1).attr(tag, header);
  }

  /**
   * Allocates a tensor of dimensions {@code wd} in the requested precision and
   * fills it from the {@code tag} attribute of {@code item}.
   */
  public static Tensor cast(String tag, Item item, TensorDims wd, Precision wp) {
    if (wp.isDouble()) {
      final double[] doubles = new double[wd.size * wd.stride];
      cast(tag, item, wd, doubles, 0);
      return new Tensor(wd, doubles);
    }
    if (wp.isSingle()) {
      final float[] floats = new float[wd.size * wd.stride];
      cast(tag, item, wd, floats, 0);
      return new Tensor(wd, floats);
    }
    throw new AssertionError();
  }

  /**
   * Reads the {@code tag} attribute of {@code item} into {@code ws} starting at
   * {@code wi}, recursing for nested dimensions. Leaf values use {@code 0.0} as
   * the fallback passed to {@code doubleValue}. Leaves {@code ws} untouched when
   * the header is undefined.
   */
  public static void cast(String tag, Item item, TensorDims wd, double[] ws, int wi) {
    final Value header = item.toValue().header(tag);
    if (header.isDefined()) {
      final TensorDims inner = wd.next;
      int cursor = wi;
      for (int i = 0; i < wd.size; i += 1) {
        if (inner != null) {
          cast(tag, header.getItem(i), inner, ws, cursor);
        } else {
          ws[cursor] = header.getItem(i).doubleValue(0.0);
        }
        cursor += wd.stride;
      }
    }
  }

  /**
   * Reads the {@code tag} attribute of {@code item} into {@code ws} starting at
   * {@code wi}, recursing for nested dimensions. Leaf values use {@code 0.0f} as
   * the fallback passed to {@code floatValue}. Leaves {@code ws} untouched when
   * the header is undefined.
   */
  public static void cast(String tag, Item item, TensorDims wd, float[] ws, int wi) {
    final Value header = item.toValue().header(tag);
    if (header.isDefined()) {
      final TensorDims inner = wd.next;
      int cursor = wi;
      for (int i = 0; i < wd.size; i += 1) {
        if (inner != null) {
          cast(tag, header.getItem(i), inner, ws, cursor);
        } else {
          ws[cursor] = header.getItem(i).floatValue(0.0f);
        }
        cursor += wd.stride;
      }
    }
  }

  /**
   * Copies the {@code double}-backed view {@code u} into the {@code double}-backed
   * view {@code w}. Innermost rows with stride 1 on both sides use
   * {@link System#arraycopy}.
   *
   * @throws DimensionException if the views differ in size or nesting depth
   */
  public static void copy(TensorDims ud, double[] us, int ui,
                          TensorDims wd, double[] ws, int wi) {
    if (ud.size != wd.size) {
      throw new DimensionException();
    }
    final boolean nested = ud.next != null;
    if (nested != (wd.next != null)) {
      throw new DimensionException();
    }
    if (!nested && ud.stride == 1 && wd.stride == 1) {
      System.arraycopy(us, ui, ws, wi, ud.size);
      return;
    }
    final int end = wi + wd.size * wd.stride;
    for (int src = ui, dst = wi; dst < end; src += ud.stride, dst += wd.stride) {
      if (nested) {
        copy(ud.next, us, src, wd.next, ws, dst);
      } else {
        ws[dst] = us[src];
      }
    }
  }

  /**
   * Copies the {@code double}-backed view {@code u} into the {@code float}-backed
   * view {@code w}, narrowing each element to {@code float}.
   *
   * @throws DimensionException if the views differ in size or nesting depth
   */
  public static void copy(TensorDims ud, double[] us, int ui,
                          TensorDims wd, float[] ws, int wi) {
    if (ud.size != wd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    final boolean nested = ud.next != null;
    if (nested != (wd.next != null)) {
      throw new DimensionException();
    }
    for (int src = ui, dst = wi; dst < end; src += ud.stride, dst += wd.stride) {
      if (nested) {
        copy(ud.next, us, src, wd.next, ws, dst);
      } else {
        ws[dst] = (float) us[src];
      }
    }
  }

  /**
   * Copies the {@code float}-backed view {@code u} into the {@code double}-backed
   * view {@code w}, widening each element to {@code double}.
   *
   * @throws DimensionException if the views differ in size or nesting depth
   */
  public static void copy(TensorDims ud, float[] us, int ui,
                          TensorDims wd, double[] ws, int wi) {
    if (ud.size != wd.size) {
      throw new DimensionException();
    }
    final int end = wi + wd.size * wd.stride;
    final boolean nested = ud.next != null;
    if (nested != (wd.next != null)) {
      throw new DimensionException();
    }
    for (int src = ui, dst = wi; dst < end; src += ud.stride, dst += wd.stride) {
      if (nested) {
        copy(ud.next, us, src, wd.next, ws, dst);
      } else {
        ws[dst] = (double) us[src];
      }
    }
  }

  /**
   * Copies the {@code float}-backed view {@code u} into the {@code float}-backed
   * view {@code w}. Innermost rows with stride 1 on both sides use
   * {@link System#arraycopy}.
   *
   * @throws DimensionException if the views differ in size or nesting depth
   */
  public static void copy(TensorDims ud, float[] us, int ui,
                          TensorDims wd, float[] ws, int wi) {
    if (ud.size != wd.size) {
      throw new DimensionException();
    }
    final boolean nested = ud.next != null;
    if (nested != (wd.next != null)) {
      throw new DimensionException();
    }
    if (!nested && ud.stride == 1 && wd.stride == 1) {
      System.arraycopy(us, ui, ws, wi, ud.size);
      return;
    }
    final int end = wi + wd.size * wd.stride;
    for (int src = ui, dst = wi; dst < end; src += ud.stride, dst += wd.stride) {
      if (nested) {
        copy(ud.next, us, src, wd.next, ws, dst);
      } else {
        ws[dst] = us[src];
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy