All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.dataset.api.DataSet Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*-
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 *
 */

package org.nd4j.linalg.dataset.api;

import com.google.common.base.Function;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.indexing.conditions.Condition;

import java.io.File;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * Created by agibsonccc on 8/26/14.
 */
public interface DataSet extends Iterable, Serializable {


    DataSet getRange(int from, int to);

    /**
     * Load the contents of the DataSet from the specified InputStream. The current contents of the DataSet (if any) will be replaced.
* The InputStream should contain a DataSet that has been serialized with {@link #save(OutputStream)} * * @param from InputStream to load the DataSet from */ void load(InputStream from); /** * Load the contents of the DataSet from the specified File. The current contents of the DataSet (if any) will be replaced.
* The InputStream should contain a DataSet that has been serialized with {@link #save(File)} * * @param from File to load the DataSet from */ void load(File from); /** * Write the contents of this DataSet to the specified OutputStream * * @param to OutputStream to save the DataSet to */ void save(OutputStream to); /** * Save this DataSet to a file. Can be loaded again using {@link } * * @param to File to sa */ void save(File to); @Deprecated DataSetIterator iterateWithMiniBatches(); String id(); /** * Returns the features array for the DataSet * * @return features array */ INDArray getFeatures(); /** * Set the features array for the DataSet * * @param features Features to set */ void setFeatures(INDArray features); /** * Calculate and return a count of each label, by index. * Assumes labels are a one-hot INDArray, for classification * * @return Map of countsn */ Map labelCounts(); void apply(Condition condition, Function function); /** * Create a copy of the DataSet * * @return Copy of the DataSet */ org.nd4j.linalg.dataset.DataSet copy(); org.nd4j.linalg.dataset.DataSet reshape(int rows, int cols); /** * Multiply the features by a scalar */ void multiplyBy(double num); /** * Divide the features by a scalar */ void divideBy(int num); /** * Shuffle the order of the rows in the DataSet. Note that this generally won't make any difference in practice * unless the DataSet is later split. */ void shuffle(); void squishToRange(double min, double max); void scaleMinAndMax(double min, double max); void scale(); void addFeatureVector(INDArray toAdd); void addFeatureVector(INDArray feature, int example); /** * Normalize this DataSet to mean 0, stdev 1 per input. * This calculates statistics based on the values in a single DataSet only. 
* For normalization over multiple DataSet objects, use {@link org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize} */ void normalize(); void binarize(); void binarize(double cutoff); /** * @deprecated Use {@link org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize} */ @Deprecated void normalizeZeroMeanZeroUnitVariance(); /** * Number of input values - i.e., size of the features INDArray per example */ int numInputs(); void validate(); int outcome(); void setNewNumberOfLabels(int labels); void setOutcome(int example, int label); org.nd4j.linalg.dataset.DataSet get(int i); org.nd4j.linalg.dataset.DataSet get(int[] i); List batchBy(int num); org.nd4j.linalg.dataset.DataSet filterBy(int[] labels); void filterAndStrip(int[] labels); /** * @deprecated prefer {@link #batchBy(int)} */ @Deprecated List dataSetBatches(int num); List sortAndBatchByNumLabels(); List batchByNumLabels(); /** * Extract each example in the DataSet into its own DataSet object, and return all of them as a list * @return List of DataSet objects, each with 1 example only */ List asList(); SplitTestAndTrain splitTestAndTrain(int numHoldout, java.util.Random rnd); SplitTestAndTrain splitTestAndTrain(int numHoldout); INDArray getLabels(); void setLabels(INDArray labels); /** * Equivalent to {@link #getFeatures()} * @deprecated Use {@link #getFeatures()} */ @Deprecated INDArray getFeatureMatrix(); void sortByLabel(); void addRow(org.nd4j.linalg.dataset.DataSet d, int i); INDArray exampleSums(); INDArray exampleMaxs(); INDArray exampleMeans(); org.nd4j.linalg.dataset.DataSet sample(int numSamples); org.nd4j.linalg.dataset.DataSet sample(int numSamples, Random rng); org.nd4j.linalg.dataset.DataSet sample(int numSamples, boolean withReplacement); org.nd4j.linalg.dataset.DataSet sample(int numSamples, Random rng, boolean withReplacement); void roundToTheNearest(int roundTo); /** * Returns the number of outcomes (size of the labels array for each example) */ int numOutcomes(); /** * Number 
of examples in the DataSet */ int numExamples(); @Deprecated List getLabelNames(); List getLabelNamesList(); String getLabelName(int idx); List getLabelNames(INDArray idxs); void setLabelNames(List labelNames); List getColumnNames(); void setColumnNames(List columnNames); /** * Split the DataSet into two DataSets randomly * @param percentTrain Percentage of examples to be returned in the training DataSet object */ SplitTestAndTrain splitTestAndTrain(double percentTrain); @Override Iterator iterator(); /** * Input mask array: a mask array for input, where each value is in {0,1} in order to specify whether an input is * actually present or not. Typically used for situations such as RNNs with variable length inputs * * @return Input mask array */ INDArray getFeaturesMaskArray(); /** * Set the features mask array in this DataSet */ void setFeaturesMaskArray(INDArray inputMask); /** * Labels (output) mask array: a mask array for input, where each value is in {0,1} in order to specify whether an * output is actually present or not. Typically used for situations such as RNNs with variable length inputs or many- * to-one situations. * * @return Labels (output) mask array */ INDArray getLabelsMaskArray(); /** * Set the labels mask array in this data set */ void setLabelsMaskArray(INDArray labelsMask); /** * Whether the labels or input (features) mask arrays are present for this DataSet */ boolean hasMaskArrays(); /** * Set the metadata for this DataSet
* By convention: the metadata can be any serializable object, one per example in the DataSet * * @param exampleMetaData Example metadata to set */ void setExampleMetaData(List exampleMetaData); /** * Get the example metadata, or null if no metadata has been set
* Note: this method results in an unchecked cast - care should be taken when using this! * * @param metaDataType Class of the metadata (used for type information) * @param Type of metadata * @return List of metadata objects */ List getExampleMetaData(Class metaDataType); /** * Get the example metadata, or null if no metadata has been set * * @return List of metadata instances * @see {@link #getExampleMetaData(Class)} for convenience method for types */ List getExampleMetaData(); }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy