All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.spark.util.MLLibUtil Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.spark.util;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.mllib.linalg.Matrices;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;


/**
 * Dl4j <----> MLLib
 *
 * @author Adam Gibson
 */
public class MLLibUtil {


    private MLLibUtil() {}

    /**
     * This is for the edge case where
     * you have a single output layer
     * and need to convert the output layer to
     * an index
     * @param vector the vector to get the classifier prediction for
     * @return the prediction for the given vector
     */
    public static double toClassifierPrediction(Vector vector) {
        double max = Double.NEGATIVE_INFINITY;
        int maxIndex = 0;
        for (int i = 0; i < vector.size(); i++) {
            double curr = vector.apply(i);
            if (curr > max) {
                maxIndex = i;
                max = curr;
            }
        }

        return maxIndex;
    }

    /**
     * Convert an ndarray to a matrix.
     * Note that the matrix will be con
     * @param arr the array
     * @return an mllib vector
     */
    public static INDArray toMatrix(Matrix arr) {

        // we assume that Matrix always has F order
        return Nd4j.create(arr.toArray(), new int[] {arr.numRows(), arr.numCols()}, 'f');
    }

    /**
     * Convert an ndarray to a vector
     * @param arr the array
     * @return an mllib vector
     */
    public static INDArray toVector(Vector arr) {
        return Nd4j.create(Nd4j.createBuffer(arr.toArray()));
    }


    /**
     * Convert an ndarray to a matrix.
     * Note that the matrix will be con
     * @param arr the array
     * @return an mllib vector
     */
    public static Matrix toMatrix(INDArray arr) {
        if (!arr.isMatrix()) {
            throw new IllegalArgumentException("passed in array must be a matrix");
        }

        // if arr is a view - we have to dup anyway
        if (arr.isView()) {
            return Matrices.dense(arr.rows(), arr.columns(), arr.dup('f').data().asDouble());
        } else // if not a view - we must ensure data is F ordered
            return Matrices.dense(arr.rows(), arr.columns(),
                            arr.ordering() == 'f' ? arr.data().asDouble() : arr.dup('f').data().asDouble());
    }

    /**
     * Convert an ndarray to a vector
     * @param arr the array
     * @return an mllib vector
     */
    public static Vector toVector(INDArray arr) {
        if (!arr.isVector()) {
            throw new IllegalArgumentException("passed in array must be a vector");
        }
        if (arr.length() > Integer.MAX_VALUE)
            throw new ND4JArraySizeException();
        double[] ret = new double[(int) arr.length()];
        for (int i = 0; i < arr.length(); i++) {
            ret[i] = arr.getDouble(i);
        }

        return Vectors.dense(ret);
    }


    /**
     * Convert a traditional sc.binaryFiles
     * in to something usable for machine learning
     * @param binaryFiles the binary files to convert
     * @param reader the reader to use
     * @return the labeled points based on the given rdd
     */
    public static JavaRDD fromBinary(JavaPairRDD binaryFiles,
                    final RecordReader reader) {
        JavaRDD> records =
                        binaryFiles.map(new Function, Collection>() {
                            @Override
                            public Collection call(
                                            Tuple2 stringPortableDataStreamTuple2)
                                            throws Exception {
                                reader.initialize(new InputStreamInputSplit(stringPortableDataStreamTuple2._2().open(),
                                                stringPortableDataStreamTuple2._1()));
                                return reader.next();
                            }
                        });

        JavaRDD ret = records.map(new Function, LabeledPoint>() {
            @Override
            public LabeledPoint call(Collection writables) throws Exception {
                return pointOf(writables);
            }
        });
        return ret;
    }

    /**
     * Convert a traditional sc.binaryFiles
     * in to something usable for machine learning
     * @param binaryFiles the binary files to convert
     * @param reader the reader to use
     * @return the labeled points based on the given rdd
     */
    public static JavaRDD fromBinary(JavaRDD> binaryFiles,
                    final RecordReader reader) {
        return fromBinary(JavaPairRDD.fromJavaRDD(binaryFiles), reader);
    }


    /**
     * Returns a labeled point of the writables
     * where the final item is the point and the rest of the items are
     * features
     * @param writables the writables
     * @return the labeled point
     */
    public static LabeledPoint pointOf(Collection writables) {
        double[] ret = new double[writables.size() - 1];
        int count = 0;
        double target = 0;
        for (Writable w : writables) {
            if (count < writables.size() - 1)
                ret[count++] = Float.parseFloat(w.toString());
            else
                target = Float.parseFloat(w.toString());
        }

        if (target < 0)
            throw new IllegalStateException("Target must be >= 0");
        return new LabeledPoint(target, Vectors.dense(ret));
    }

    /**
     * Convert an rdd
     * of labeled point
     * based on the specified batch size
     * in to data set
     * @param data the data to convert
     * @param numPossibleLabels the number of possible labels
     * @param batchSize the batch size
     * @return the new rdd
     */
    public static JavaRDD fromLabeledPoint(JavaRDD data, final long numPossibleLabels,
                    long batchSize) {

        JavaRDD mappedData = data.map(new Function() {
            @Override
            public DataSet call(LabeledPoint lp) {
                return fromLabeledPoint(lp, numPossibleLabels);
            }
        });

        return mappedData.repartition((int) (mappedData.count() / batchSize));
    }

    /**
     * From labeled point
     * @param sc the org.deeplearning4j.spark context used for creating the rdd
     * @param data the data to convert
     * @param numPossibleLabels the number of possible labels
     * @return
     * @deprecated Use {@link #fromLabeledPoint(JavaRDD, int)}
     */
    @Deprecated
    public static JavaRDD fromLabeledPoint(JavaSparkContext sc, JavaRDD data,
                    final long numPossibleLabels) {
        return data.map(new Function() {
            @Override
            public DataSet call(LabeledPoint lp) {
                return fromLabeledPoint(lp, numPossibleLabels);
            }
        });
    }

    /**
     * Convert rdd labeled points to a rdd dataset with continuous features
     * @param data the java rdd labeled points ready to convert
     * @return a JavaRDD with a continuous label
     * @deprecated Use {@link #fromContinuousLabeledPoint(JavaRDD)}
     */
    @Deprecated
    public static JavaRDD fromContinuousLabeledPoint(JavaSparkContext sc, JavaRDD data) {

        return data.map(new Function() {
            @Override
            public DataSet call(LabeledPoint lp) {
                return convertToDataset(lp);
            }
        });
    }

    private static DataSet convertToDataset(LabeledPoint lp) {
        Vector features = lp.features();
        double label = lp.label();
        return new DataSet(Nd4j.create(features.toArray()), Nd4j.create(new double[] {label}));
    }

    /**
     * Convert an rdd of data set in to labeled point
     * @param sc the spark context to use
     * @param data the dataset to convert
     * @return an rdd of labeled point
     * @deprecated Use {@link #fromDataSet(JavaRDD)}
     *
     */
    @Deprecated
    public static JavaRDD fromDataSet(JavaSparkContext sc, JavaRDD data) {

        return data.map(new Function() {
            @Override
            public LabeledPoint call(DataSet pt) {
                return toLabeledPoint(pt);
            }
        });
    }

    /**
     * Convert a list of dataset in to a list of labeled points
     * @param labeledPoints the labeled points to convert
     * @return the labeled point list
     */
    private static List toLabeledPoint(List labeledPoints) {
        List ret = new ArrayList<>();
        for (DataSet point : labeledPoints) {
            ret.add(toLabeledPoint(point));
        }
        return ret;
    }

    /**
     * Convert a dataset (feature vector) to a labeled point
     * @param point the point to convert
     * @return the labeled point derived from this dataset
     */
    private static LabeledPoint toLabeledPoint(DataSet point) {
        if (!point.getFeatures().isVector()) {
            throw new IllegalArgumentException("Feature matrix must be a vector");
        }

        Vector features = toVector(point.getFeatures().dup());

        double label = Nd4j.getBlasWrapper().iamax(point.getLabels());
        return new LabeledPoint(label, features);
    }

    /**
     * Converts a continuous JavaRDD LabeledPoint to a JavaRDD DataSet.
     * @param data JavaRDD LabeledPoint
     * @return JavaRdd DataSet
     */
    public static JavaRDD fromContinuousLabeledPoint(JavaRDD data) {
        return fromContinuousLabeledPoint(data, false);
    }

    /**
     * Converts a continuous JavaRDD LabeledPoint to a JavaRDD DataSet.
     * @param data JavaRdd LabeledPoint
     * @param preCache boolean pre-cache rdd before operation
     * @return
     */
    public static JavaRDD fromContinuousLabeledPoint(JavaRDD data, boolean preCache) {
        if (preCache && !data.getStorageLevel().useMemory()) {
            data.cache();
        }
        return data.map(new Function() {
            @Override
            public DataSet call(LabeledPoint lp) {
                return convertToDataset(lp);
            }
        });
    }

    /**
     * Converts JavaRDD labeled points to JavaRDD datasets.
     * @param data JavaRDD LabeledPoints
     * @param numPossibleLabels number of possible labels
     * @return
     */
    public static JavaRDD fromLabeledPoint(JavaRDD data, final long numPossibleLabels) {
        return fromLabeledPoint(data, numPossibleLabels, false);
    }

    /**
     * Converts JavaRDD labeled points to JavaRDD DataSets.
     * @param data JavaRDD LabeledPoints
     * @param numPossibleLabels number of possible labels
     * @param preCache boolean pre-cache rdd before operation
     * @return
     */
    public static JavaRDD fromLabeledPoint(JavaRDD data, final long numPossibleLabels,
                    boolean preCache) {
        if (preCache && !data.getStorageLevel().useMemory()) {
            data.cache();
        }
        return data.map(new Function() {
            @Override
            public DataSet call(LabeledPoint lp) {
                return fromLabeledPoint(lp, numPossibleLabels);
            }
        });
    }

    /**
     * Convert an rdd of data set in to labeled point.
     * @param data the dataset to convert
     * @return an rdd of labeled point
     */
    public static JavaRDD fromDataSet(JavaRDD data) {
        return fromDataSet(data, false);
    }

    /**
     * Convert an rdd of data set in to labeled point.
     * @param data the dataset to convert
     * @param preCache boolean pre-cache rdd before operation
     * @return an rdd of labeled point
     */
    public static JavaRDD fromDataSet(JavaRDD data, boolean preCache) {
        if (preCache && !data.getStorageLevel().useMemory()) {
            data.cache();
        }
        return data.map(new Function() {
            @Override
            public LabeledPoint call(DataSet dataSet) {
                return toLabeledPoint(dataSet);
            }
        });
    }


    /**
     *
     * @param labeledPoints
     * @param numPossibleLabels
     * @return List of {@link DataSet}
     */
    private static List fromLabeledPoint(List labeledPoints, long numPossibleLabels) {
        List ret = new ArrayList<>();
        for (LabeledPoint point : labeledPoints) {
            ret.add(fromLabeledPoint(point, numPossibleLabels));
        }
        return ret;
    }

    /**
     *
     * @param point
     * @param numPossibleLabels
     * @return {@link DataSet}
     */
    private static DataSet fromLabeledPoint(LabeledPoint point, long numPossibleLabels) {
        Vector features = point.features();
        double label = point.label();

        // FIXMEL int cast
        double[] fArr = features.toArray();
        return new DataSet(Nd4j.create(fArr, new long[]{1,fArr.length}),
                        FeatureUtil.toOutcomeVector((int) label, (int) numPossibleLabels));
    }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy