All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.ignite.ml.nn.ReplicatedVectorMatrix Maven / Gradle / Ivy

Go to download

Apache Ignite® is a Distributed Database For High-Performance Computing With In-Memory Speed.

There is a newer version: 2.15.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.nn;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Map;
import java.util.Spliterator;
import org.apache.ignite.lang.IgniteUuid;
import org.apache.ignite.ml.math.Matrix;
import org.apache.ignite.ml.math.MatrixStorage;
import org.apache.ignite.ml.math.Vector;
import org.apache.ignite.ml.math.exceptions.CardinalityException;
import org.apache.ignite.ml.math.functions.IgniteBiConsumer;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.functions.IgniteTriFunction;
import org.apache.ignite.ml.math.functions.IntIntToDoubleFunction;
import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;

/**
 * Convenient way to create matrix of replicated columns or rows from vector.
 * This class should be considered as utility class: not all matrix methods are implemented here, only those which
 * were necessary for MLPs.
 */
class ReplicatedVectorMatrix implements Matrix {
    /**
     * Vector to replicate. It is shared, not copied: mutating operations of this class
     * ({@link #set(int, int, double)}, {@link #swapRows(int, int)}, {@link #assign(double)}, ...) write into it.
     */
    private Vector vector;

    /**
     * Flag determining is vector replicated as column ({@code true}) or as row ({@code false}).
     */
    private boolean asCol;

    /**
     * Count of vector replications.
     */
    private int replicationCnt;

    /**
     * Construct ReplicatedVectorMatrix.
     *
     * @param vector Vector to replicate.
     * @param replicationCnt Count of replications.
     * @param asCol Should vector be replicated as a column or as a row.
     */
    ReplicatedVectorMatrix(Vector vector, int replicationCnt, boolean asCol) {
        this.vector = vector;
        this.asCol = asCol;
        this.replicationCnt = replicationCnt;
    }

    /**
     * Constructor for externalization.
     */
    public ReplicatedVectorMatrix() {
        // No-op.
    }

    /** {@inheritDoc} */
    @Override public boolean isSequentialAccess() {
        return vector.isSequentialAccess();
    }

    /** {@inheritDoc} */
    @Override public boolean isRandomAccess() {
        return vector.isRandomAccess();
    }

    /** {@inheritDoc} */
    @Override public boolean isDense() {
        return vector.isDense();
    }

    /** {@inheritDoc} */
    @Override public boolean isArrayBased() {
        return vector.isArrayBased();
    }

    /** {@inheritDoc} */
    @Override public boolean isDistributed() {
        return vector.isDistributed();
    }

    /** {@inheritDoc} */
    @Override public double maxValue() {
        // Every replica holds the same values, so the extremum of the vector is the extremum of the matrix.
        return vector.maxValue();
    }

    /** {@inheritDoc} */
    @Override public double minValue() {
        return vector.minValue();
    }

    /**
     * {@inheritDoc}
     *
     * The returned element is evaluated lazily against the replicated vector on every call and reports its
     * position inside the first replica (column 0 or row 0).
     */
    @Override public Element maxElement() {
        return new Element() {
            @Override public double get() {
                return vector.maxElement().get();
            }

            @Override public int row() {
                return asCol ? vector.maxElement().index() : 0;
            }

            @Override public int column() {
                return asCol ? 0 : vector.maxElement().index();
            }

            @Override public void set(double val) {
                // No-op: this element is a read-only handle.
            }
        };
    }

    /**
     * {@inheritDoc}
     *
     * The returned element is evaluated lazily against the replicated vector on every call and reports its
     * position inside the first replica (column 0 or row 0).
     */
    @Override public Element minElement() {
        return new Element() {
            @Override public double get() {
                return vector.minElement().get();
            }

            @Override public int row() {
                return asCol ? vector.minElement().index() : 0;
            }

            @Override public int column() {
                return asCol ? 0 : vector.minElement().index();
            }

            @Override public void set(double val) {
                // No-op: this element is a read-only handle.
            }
        };
    }

    /**
     * {@inheritDoc}
     *
     * Only the index along the replicated vector is used; the reported position is the one inside the first
     * replica (column 0 or row 0).
     */
    @Override public Element getElement(int row, int col) {
        Vector.Element el = asCol ? vector.getElement(row) : vector.getElement(col);
        int r = asCol ? el.index() : 0;
        int c = asCol ? 0 : el.index();

        return new Element() {
            @Override public double get() {
                return el.get();
            }

            @Override public int row() {
                return r;
            }

            @Override public int column() {
                return c;
            }

            @Override public void set(double val) {
                // No-op: this element is a read-only handle.
            }
        };
    }

    /** {@inheritDoc} */
    @Override public Matrix swapRows(int row1, int row2) {
        // When the vector is replicated as a row all rows are identical, so swapping them changes nothing.
        return asCol ? new ReplicatedVectorMatrix(swap(row1, row2), replicationCnt, asCol) : this;
    }

    /**
     * Swaps two components of the replicated vector in place.
     *
     * @param idx1 Index of the first component.
     * @param idx2 Index of the second component.
     * @return The replicated vector itself (not a copy) after the swap.
     */
    private Vector swap(int idx1, int idx2) {
        double val = vector.getX(idx1);

        vector.setX(idx1, vector.getX(idx2));
        vector.setX(idx2, val);

        return vector;
    }

    /** {@inheritDoc} */
    @Override public Matrix swapColumns(int col1, int col2) {
        // When the vector is replicated as a column all columns are identical, so swapping them changes nothing.
        return asCol ? this : new ReplicatedVectorMatrix(swap(col1, col2), replicationCnt, asCol);
    }

    /** {@inheritDoc} */
    @Override public Matrix assign(double val) {
        return new ReplicatedVectorMatrix(vector.assign(val), replicationCnt, asCol);
    }

    /**
     * {@inheritDoc}
     *
     * Arbitrary values cannot be represented by a replicated vector, so this matrix is left untouched and a new
     * dense matrix built from {@code vals} is returned.
     */
    @Override public Matrix assign(double[][] vals) {
        return new DenseLocalOnHeapMatrix(vals);
    }

    /**
     * {@inheritDoc}
     *
     * This matrix is left untouched; a copy of {@code mtx} is returned.
     */
    @Override public Matrix assign(Matrix mtx) {
        return mtx.copy();
    }

    /**
     * {@inheritDoc}
     *
     * The function is evaluated only along the first replica (column 0 or row 0).
     */
    @Override public Matrix assign(IntIntToDoubleFunction fun) {
        Vector vec = asCol ? this.vector.assign(idx -> fun.apply(idx, 0)) : this.vector.assign(idx -> fun.apply(0, idx));

        return new ReplicatedVectorMatrix(vec, replicationCnt, asCol);
    }

    /** {@inheritDoc} */
    @Override public Matrix map(IgniteDoubleFunction<Double> fun) {
        Vector vec = vector.map(fun);

        return new ReplicatedVectorMatrix(vec, replicationCnt, asCol);
    }

    /** {@inheritDoc} */
    @Override public Matrix map(Matrix mtx, IgniteBiFunction<Double, Double, Double> fun) {
        throw new UnsupportedOperationException();
    }

    /** {@inheritDoc} */
    @Override public int nonZeroElements() {
        // Each of the replicationCnt replicas contributes the same number of non-zero entries.
        return vector.nonZeroElements() * replicationCnt;
    }

    /** {@inheritDoc} */
    @Override public Spliterator<Double> allSpliterator() {
        throw new UnsupportedOperationException();
    }

    /** {@inheritDoc} */
    @Override public Spliterator<Double> nonZeroSpliterator() {
        throw new UnsupportedOperationException();
    }

    /**
     * {@inheritDoc}
     *
     * This matrix is left untouched: the result is a new dense matrix holding all replicas with column
     * {@code col} replaced by {@code vec}.
     */
    @Override public Matrix assignColumn(int col, Vector vec) {
        int rows = asCol ? vector.size() : replicationCnt;
        int cols = asCol ? replicationCnt : vector.size();
        int times = asCol ? cols : rows;

        Matrix res = new DenseLocalOnHeapMatrix(rows, cols);

        IgniteBiConsumer<Integer, Vector> replicantAssigner = asCol ? res::assignColumn : res::assignRow;
        IgniteBiConsumer<Integer, Vector> assigner = res::assignColumn;

        assign(replicantAssigner, assigner, vector, vec, times, col);

        return res;
    }

    /**
     * {@inheritDoc}
     *
     * This matrix is left untouched: the result is a new dense matrix holding all replicas with row
     * {@code row} replaced by {@code vec}.
     */
    @Override public Matrix assignRow(int row, Vector vec) {
        int rows = asCol ? vector.size() : replicationCnt;
        int cols = asCol ? replicationCnt : vector.size();
        int times = asCol ? cols : rows;

        Matrix res = new DenseLocalOnHeapMatrix(rows, cols);

        IgniteBiConsumer<Integer, Vector> replicantAssigner = asCol ? res::assignColumn : res::assignRow;
        IgniteBiConsumer<Integer, Vector> assigner = res::assignRow;

        assign(replicantAssigner, assigner, vector, vec, times, row);

        return res;
    }

    /**
     * Fills a target matrix with replicas and then overwrites a single row/column of it.
     *
     * @param replicantAssigner Writes a replica at the given index of the target matrix.
     * @param assigner Writes the overriding vector at the given index of the target matrix.
     * @param replicant Vector to replicate.
     * @param vector Vector which overrides the row/column with index {@code idx}.
     * @param times Count of replicas to write.
     * @param idx Index of the row/column to override.
     */
    private void assign(IgniteBiConsumer<Integer, Vector> replicantAssigner,
        IgniteBiConsumer<Integer, Vector> assigner, Vector replicant, Vector vector, int times, int idx) {
        for (int i = 0; i < times; i++)
            replicantAssigner.accept(i, replicant);

        // Must run after the replicas are written, otherwise it would be overwritten by a replica.
        assigner.accept(idx, vector);
    }

    /** {@inheritDoc} */
    @Override public Vector foldRows(IgniteFunction<Vector, Double> fun) {
        throw new UnsupportedOperationException();
    }

    /** {@inheritDoc} */
    @Override public Vector foldColumns(IgniteFunction<Vector, Double> fun) {
        throw new UnsupportedOperationException();
    }

    /** {@inheritDoc} */
    @Override public <T> T foldMap(IgniteBiFunction<T, Double, T> foldFun, IgniteDoubleFunction<Double> mapFun,
        T zeroVal) {
        throw new UnsupportedOperationException();
    }

    /**
     * {@inheritDoc}
     *
     * Not implemented for this utility class: always returns {@code false}.
     */
    @Override public boolean density(double threshold) {
        return false;
    }

    /** {@inheritDoc} */
    @Override public int columnSize() {
        return asCol ? replicationCnt : vector.size();
    }

    /** {@inheritDoc} */
    @Override public int rowSize() {
        return asCol ? vector.size() : replicationCnt;
    }

    /**
     * {@inheritDoc}
     *
     * @throws CardinalityException If matrix is not square.
     */
    @Override public double determinant() {
        int size = vector.size();

        // The matrix is square only when the vector is replicated exactly as many times as it has components.
        if (size != replicationCnt)
            throw new CardinalityException(rowSize(), columnSize());

        // A 1x1 matrix has its single element as determinant; with more than one replica the rows/columns
        // are linearly dependent and determinant is 0.
        return size == 1 ? vector.get(0) : 0;
    }

    /** {@inheritDoc} */
    @Override public Matrix inverse() {
        throw new UnsupportedOperationException();
    }

    /** {@inheritDoc} */
    @Override public Matrix divide(double x) {
        return new ReplicatedVectorMatrix(vector.divide(x), replicationCnt, asCol);
    }

    /** {@inheritDoc} */
    @Override public double get(int row, int col) {
        return asCol ? vector.get(row) : vector.get(col);
    }

    /** {@inheritDoc} */
    @Override public double getX(int row, int col) {
        return asCol ? vector.getX(row) : vector.getX(col);
    }

    /** {@inheritDoc} */
    @Override public MatrixStorage getStorage() {
        throw new UnsupportedOperationException();
    }

    /** {@inheritDoc} */
    @Override public Matrix copy() {
        Vector cp = vector.copy();

        return new ReplicatedVectorMatrix(cp, replicationCnt, asCol);
    }

    /**
     * {@inheritDoc}
     *
     * The result keeps the orientation of this matrix and has the requested dimensions.
     */
    @Override public Matrix like(int rows, int cols) {
        // Length of the replicated vector and count of replicas follow from the orientation.
        return asCol
            ? new ReplicatedVectorMatrix(vector.like(rows), cols, true)
            : new ReplicatedVectorMatrix(vector.like(cols), rows, false);
    }

    /** {@inheritDoc} */
    @Override public Vector likeVector(int crd) {
        throw new UnsupportedOperationException();
    }

    /** {@inheritDoc} */
    @Override public Matrix minus(Matrix mtx) {
        throw new UnsupportedOperationException();
    }

    /**
     * Specialized optimized version of minus for ReplicatedVectorMatrix.
     *
     * @param mtx Matrix to be subtracted.
     * @return new ReplicatedVectorMatrix resulting from subtraction.
     * @throws CardinalityException If dimensions of matrices differ.
     * @throws UnsupportedOperationException If matrices are replicated in different directions.
     */
    public Matrix minus(ReplicatedVectorMatrix mtx) {
        if (isColumnReplicated() == mtx.isColumnReplicated()) {
            checkCardinality(mtx.rowSize(), mtx.columnSize());

            Vector minus = vector.minus(mtx.replicant());

            return new ReplicatedVectorMatrix(minus, replicationCnt, asCol);
        }

        throw new UnsupportedOperationException();
    }

    /** {@inheritDoc} */
    @Override public Matrix plus(double x) {
        throw new UnsupportedOperationException();
    }

    /** {@inheritDoc} */
    @Override public Matrix plus(Matrix mtx) {
        throw new UnsupportedOperationException();
    }

    /**
     * Specialized optimized version of plus for ReplicatedVectorMatrix.
     *
     * @param mtx Matrix to be added.
     * @return new ReplicatedVectorMatrix resulting from addition.
     * @throws CardinalityException If dimensions of matrices differ.
     * @throws UnsupportedOperationException If matrices are replicated in different directions.
     */
    public Matrix plus(ReplicatedVectorMatrix mtx) {
        if (isColumnReplicated() == mtx.isColumnReplicated()) {
            checkCardinality(mtx.rowSize(), mtx.columnSize());

            Vector plus = vector.plus(mtx.replicant());

            return new ReplicatedVectorMatrix(plus, replicationCnt, asCol);
        }

        throw new UnsupportedOperationException();
    }

    /**
     * Checks that dimensions of this matrix are equal to given dimensions.
     *
     * @param rows Rows.
     * @param cols Columns.
     * @throws CardinalityException If either dimension differs.
     */
    private void checkCardinality(int rows, int cols) {
        if (rows != rowSize())
            throw new CardinalityException(rowSize(), rows);

        if (cols != columnSize())
            throw new CardinalityException(columnSize(), cols);
    }

    /**
     * {@inheritDoc}
     *
     * This utility matrix has no GUID: always returns {@code null}.
     */
    @Override public IgniteUuid guid() {
        return null;
    }

    /**
     * {@inheritDoc}
     *
     * The value is written into the replicated vector, therefore it changes in every replica.
     */
    @Override public Matrix set(int row, int col, double val) {
        vector.set(asCol ? row : col, val);

        return this;
    }

    /**
     * {@inheritDoc}
     *
     * Not implemented for this utility class: always returns {@code null}.
     */
    @Override public Matrix setRow(int row, double[] data) {
        return null;
    }

    /**
     * {@inheritDoc}
     *
     * The returned vector is a fresh one; changing it does not affect this matrix.
     */
    @Override public Vector getRow(int row) {
        // Column replication: a row repeats one component of the vector replicationCnt times.
        // Row replication: every row is the replicated vector itself.
        return asCol ? vector.like(replicationCnt).assign(vector.get(row)) : vector.copy();
    }

    /**
     * {@inheritDoc}
     *
     * Not implemented for this utility class: always returns {@code null}.
     */
    @Override public Matrix setColumn(int col, double[] data) {
        return null;
    }

    /**
     * {@inheritDoc}
     *
     * The returned vector is a fresh one; changing it does not affect this matrix.
     */
    @Override public Vector getCol(int col) {
        // Column replication: every column is the replicated vector itself.
        // Row replication: a column repeats one component of the vector replicationCnt times.
        return asCol ? vector.copy() : vector.like(replicationCnt).assign(vector.get(col));
    }

    /**
     * {@inheritDoc}
     *
     * Same as {@link #set(int, int, double)} but goes through the vector's {@code setX}.
     */
    @Override public Matrix setX(int row, int col, double val) {
        vector.setX(asCol ? row : col, val);

        return this;
    }

    /** {@inheritDoc} */
    @Override public Matrix times(double x) {
        return new ReplicatedVectorMatrix(vector.times(x), replicationCnt, asCol);
    }

    /**
     * {@inheritDoc}
     *
     * Supported only for row-replicated matrices: all rows of the product are equal, so the product is again
     * a row-replicated matrix.
     *
     * @throws UnsupportedOperationException If this matrix is column-replicated.
     */
    @Override public Matrix times(Matrix mtx) {
        if (asCol)
            throw new UnsupportedOperationException();

        int cols = mtx.columnSize();

        Vector row = vector.like(cols);

        for (int i = 0; i < cols; i++)
            row.setX(i, vector.dot(mtx.getCol(i)));

        return new ReplicatedVectorMatrix(row, replicationCnt, false);
    }

    /** {@inheritDoc} */
    @Override public Vector times(Vector vec) {
        int rows = rowSize();

        // Product of (rows x cols) matrix and a vector has exactly 'rows' components.
        Vector res = vec.like(rows);

        if (asCol) {
            // Each row is a single component repeated, so (row . vec) = component * sum(vec).
            double sum = vec.sum();

            for (int i = 0; i < rows; i++)
                res.setX(i, sum * vector.getX(i));
        }
        else {
            // All rows are equal to the replicated vector, so all components of the product are equal.
            double val = vector.dot(vec);

            for (int i = 0; i < rows; i++)
                res.setX(i, val);
        }

        return res;
    }

    /**
     * {@inheritDoc}
     *
     * Not implemented for this utility class: always returns {@code 0}.
     */
    @Override public double maxAbsRowSumNorm() {
        return 0;
    }

    /** {@inheritDoc} */
    @Override public double sum() {
        return vector.sum() * replicationCnt;
    }

    /**
     * {@inheritDoc}
     *
     * The transposed matrix shares the replicated vector with this one.
     */
    @Override public Matrix transpose() {
        return new ReplicatedVectorMatrix(vector, replicationCnt, !asCol);
    }

    /**
     * {@inheritDoc}
     *
     * Not implemented for this utility class: always returns {@code null}.
     */
    @Override public Matrix viewPart(int[] off, int[] size) {
        return null;
    }

    /**
     * {@inheritDoc}
     *
     * Not implemented for this utility class: always returns {@code null}.
     */
    @Override public Matrix viewPart(int rowOff, int rows, int colOff, int cols) {
        return null;
    }

    /**
     * {@inheritDoc}
     *
     * Not implemented for this utility class: always returns {@code null}.
     */
    @Override public Vector viewRow(int row) {
        return null;
    }

    /**
     * {@inheritDoc}
     *
     * Not implemented for this utility class: always returns {@code null}.
     */
    @Override public Vector viewColumn(int col) {
        return null;
    }

    /**
     * {@inheritDoc}
     *
     * Not implemented for this utility class: always returns {@code null}.
     */
    @Override public Vector viewDiagonal() {
        return null;
    }

    /** {@inheritDoc} */
    @Override public void compute(int row, int col, IgniteTriFunction<Integer, Integer, Double, Double> f) {
        // This operation cannot be performed because computing function depends on both indexes and therefore
        // result of compute will be in general case not ReplicatedVectorMatrix.
        throw new UnsupportedOperationException();
    }

    /** {@inheritDoc} */
    @Override public void writeExternal(ObjectOutput out) throws IOException {
        throw new UnsupportedOperationException();
    }

    /** {@inheritDoc} */
    @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
        throw new UnsupportedOperationException();
    }

    /**
     * {@inheritDoc}
     *
     * This utility matrix keeps no meta attributes: always returns {@code null}.
     */
    @Override public Map<String, Object> getMetaStorage() {
        return null;
    }

    /**
     * Returns true if matrix constructed by replicating vector as column and false otherwise.
     *
     * @return {@code true} if the vector is replicated as a column, {@code false} if as a row.
     */
    public boolean isColumnReplicated() {
        return asCol;
    }

    /**
     * Returns replicated vector.
     *
     * @return The replicated vector itself (not a copy).
     */
    public Vector replicant() {
        return vector;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy