All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair Maven / Gradle / Ivy

Go to download

Apache Ignite® is a Distributed Database For High-Performance Computing With In-Memory Speed.

There is a newer version: 2.15.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.ignite.ml.structures;

import java.io.Serializable;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;
import java.util.TreeSet;
import org.jetbrains.annotations.NotNull;

/**
 * Class for splitting Labeled Dataset on train and test sets.
 */
public class LabeledDatasetTestTrainPair implements Serializable {
    /** Data to keep train set. */
    private LabeledDataset train;

    /** Data to keep test set. */
    private LabeledDataset test;

    /**
     * Creates two subsets of given dataset.
     * 

* NOTE: This method uses next algorithm with O(n log n) by calculations and O(n) by memory. *

* @param dataset The dataset to split on train and test subsets. * @param testPercentage The percentage of the test subset. */ public LabeledDatasetTestTrainPair(LabeledDataset dataset, double testPercentage) { assert testPercentage > 0.0; assert testPercentage < 1.0; final int datasetSize = dataset.rowSize(); assert datasetSize > 2; final int testSize = (int)Math.floor(datasetSize * testPercentage); final int trainSize = datasetSize - testSize; final TreeSet sortedTestIndices = getSortedIndices(datasetSize, testSize); LabeledVector[] testVectors = new LabeledVector[testSize]; LabeledVector[] trainVectors = new LabeledVector[trainSize]; int datasetCntr = 0; int trainCntr = 0; int testCntr = 0; for (Integer idx: sortedTestIndices){ // guarantee order as iterator testVectors[testCntr] = (LabeledVector)dataset.getRow(idx); testCntr++; for (int i = datasetCntr; i < idx; i++) { trainVectors[trainCntr] = (LabeledVector)dataset.getRow(i); trainCntr++; } datasetCntr = idx + 1; } if(datasetCntr < datasetSize){ for (int i = datasetCntr; i < datasetSize; i++) { trainVectors[trainCntr] = (LabeledVector)dataset.getRow(i); trainCntr++; } } test = new LabeledDataset(testVectors, dataset.colSize()); train = new LabeledDataset(trainVectors, dataset.colSize()); } /** This method generates "random double, integer" pairs, sort them, gets first "testSize" elements and returns appropriate indices */ @NotNull private TreeSet getSortedIndices(int datasetSize, int testSize) { Random rnd = new Random(); TreeMap randomIdxPairs = new TreeMap<>(); for (int i = 0; i < datasetSize; i++) randomIdxPairs.put(rnd.nextDouble(), i); final TreeMap testIdxPairs = randomIdxPairs.entrySet().stream() .limit(testSize) .collect(TreeMap::new, (m, e) -> m.put(e.getKey(), e.getValue()), Map::putAll); return new TreeSet<>(testIdxPairs.values()); } /** * Train subset of the whole dataset. * @return Train subset. */ public LabeledDataset train() { return train; } /** * Test subset of the whole dataset. * @return Test subset. */ public LabeledDataset test() { return test; } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy