// All downloads are free. Search and download functionality uses the official Maven repository.
//
// org.datavec.api.transform.reduce.AggregableReductionUtils (Maven / Gradle / Ivy)
//
// There is a newer version: 1.0.0-M2.1
// Show newest version
/*-
 *  * Copyright 2016 Skymind, Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 */

package org.datavec.api.transform.reduce;

import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.ReduceOp;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.ops.*;
import org.datavec.api.writable.Writable;

import java.util.ArrayList;
import java.util.List;

/**
 * Various utilities for performing reductions
 *
 * @author Alex Black
 */
public class AggregableReductionUtils {

    private AggregableReductionUtils() {}


    public static IAggregableReduceOp> reduceColumn(List op, ColumnType type,
                    boolean ignoreInvalid, ColumnMetaData metaData) {
        switch (type) {
            case Integer:
                return reduceIntColumn(op, ignoreInvalid, metaData);
            case Long:
                return reduceLongColumn(op, ignoreInvalid, metaData);
            case Float:
                return reduceFloatColumn(op, ignoreInvalid, metaData);
            case Double:
                return reduceDoubleColumn(op, ignoreInvalid, metaData);
            case String:
            case Categorical:
                return reduceStringOrCategoricalColumn(op, ignoreInvalid, metaData);
            case Time:
                return reduceTimeColumn(op, ignoreInvalid, metaData);
            case Bytes:
                return reduceBytesColumn(op, ignoreInvalid, metaData);
            default:
                throw new UnsupportedOperationException("Unknown or not implemented column type: " + type);
        }
    }

    public static IAggregableReduceOp> reduceIntColumn(List lop,
                    boolean ignoreInvalid, ColumnMetaData metaData) {

        List> res = new ArrayList<>(lop.size());
        for (int i = 0; i < lop.size(); i++) {
            switch (lop.get(i)) {
                case Prod:
                    res.add(new AggregatorImpls.AggregableProd());
                    break;
                case Min:
                    res.add(new AggregatorImpls.AggregableMin());
                    break;
                case Max:
                    res.add(new AggregatorImpls.AggregableMax());
                    break;
                case Range:
                    res.add(new AggregatorImpls.AggregableRange());
                    break;
                case Sum:
                    res.add(new AggregatorImpls.AggregableSum());
                    break;
                case Mean:
                    res.add(new AggregatorImpls.AggregableMean());
                    break;
                case Stdev:
                    res.add(new AggregatorImpls.AggregableStdDev());
                    break;
                case UncorrectedStdDev:
                    res.add(new AggregatorImpls.AggregableUncorrectedStdDev());
                    break;
                case Variance:
                    res.add(new AggregatorImpls.AggregableVariance());
                    break;
                case PopulationVariance:
                    res.add(new AggregatorImpls.AggregablePopulationVariance());
                    break;
                case Count:
                    res.add(new AggregatorImpls.AggregableCount());
                    break;
                case CountUnique:
                    res.add(new AggregatorImpls.AggregableCountUnique());
                    break;
                case TakeFirst:
                    res.add(new AggregatorImpls.AggregableFirst());
                    break;
                case TakeLast:
                    res.add(new AggregatorImpls.AggregableLast());
                    break;
                default:
                    throw new UnsupportedOperationException("Unknown or not implemented op: " + lop.get(i));
            }
        }
        IAggregableReduceOp> thisOp = new IntWritableOp<>(new AggregableMultiOp<>(res));
        if (ignoreInvalid)
            return new AggregableCheckingOp<>(thisOp, metaData);
        else
            return thisOp;
    }

    public static IAggregableReduceOp> reduceLongColumn(List lop,
                    boolean ignoreInvalid, ColumnMetaData metaData) {

        List> res = new ArrayList<>(lop.size());
        for (int i = 0; i < lop.size(); i++) {
            switch (lop.get(i)) {
                case Prod:
                    res.add(new AggregatorImpls.AggregableProd());
                    break;
                case Min:
                    res.add(new AggregatorImpls.AggregableMin());
                    break;
                case Max:
                    res.add(new AggregatorImpls.AggregableMax());
                    break;
                case Range:
                    res.add(new AggregatorImpls.AggregableRange());
                    break;
                case Sum:
                    res.add(new AggregatorImpls.AggregableSum());
                    break;
                case Stdev:
                    res.add(new AggregatorImpls.AggregableStdDev());
                    break;
                case UncorrectedStdDev:
                    res.add(new AggregatorImpls.AggregableUncorrectedStdDev());
                    break;
                case Variance:
                    res.add(new AggregatorImpls.AggregableVariance());
                    break;
                case PopulationVariance:
                    res.add(new AggregatorImpls.AggregablePopulationVariance());
                    break;
                case Mean:
                    res.add(new AggregatorImpls.AggregableMean());
                    break;
                case Count:
                    res.add(new AggregatorImpls.AggregableCount());
                    break;
                case CountUnique:
                    res.add(new AggregatorImpls.AggregableCountUnique());
                    break;
                case TakeFirst:
                    res.add(new AggregatorImpls.AggregableFirst());
                    break;
                case TakeLast:
                    res.add(new AggregatorImpls.AggregableLast());
                    break;
                default:
                    throw new UnsupportedOperationException("Unknown or not implemented op: " + lop.get(i));
            }
        }
        IAggregableReduceOp> thisOp = new LongWritableOp<>(new AggregableMultiOp<>(res));
        if (ignoreInvalid)
            return new AggregableCheckingOp<>(thisOp, metaData);
        else
            return thisOp;
    }

    public static IAggregableReduceOp> reduceFloatColumn(List lop,
                    boolean ignoreInvalid, ColumnMetaData metaData) {

        List> res = new ArrayList<>(lop.size());
        for (int i = 0; i < lop.size(); i++) {
            switch (lop.get(i)) {
                case Prod:
                    res.add(new AggregatorImpls.AggregableProd());
                    break;
                case Min:
                    res.add(new AggregatorImpls.AggregableMin());
                    break;
                case Max:
                    res.add(new AggregatorImpls.AggregableMax());
                    break;
                case Range:
                    res.add(new AggregatorImpls.AggregableRange());
                    break;
                case Sum:
                    res.add(new AggregatorImpls.AggregableSum());
                    break;
                case Mean:
                    res.add(new AggregatorImpls.AggregableMean());
                    break;
                case Stdev:
                    res.add(new AggregatorImpls.AggregableStdDev());
                    break;
                case UncorrectedStdDev:
                    res.add(new AggregatorImpls.AggregableUncorrectedStdDev());
                    break;
                case Variance:
                    res.add(new AggregatorImpls.AggregableVariance());
                    break;
                case PopulationVariance:
                    res.add(new AggregatorImpls.AggregablePopulationVariance());
                    break;
                case Count:
                    res.add(new AggregatorImpls.AggregableCount());
                    break;
                case CountUnique:
                    res.add(new AggregatorImpls.AggregableCountUnique());
                    break;
                case TakeFirst:
                    res.add(new AggregatorImpls.AggregableFirst());
                    break;
                case TakeLast:
                    res.add(new AggregatorImpls.AggregableLast());
                    break;
                default:
                    throw new UnsupportedOperationException("Unknown or not implemented op: " + lop.get(i));
            }
        }
        IAggregableReduceOp> thisOp = new FloatWritableOp<>(new AggregableMultiOp<>(res));
        if (ignoreInvalid)
            return new AggregableCheckingOp<>(thisOp, metaData);
        else
            return thisOp;
    }

    public static IAggregableReduceOp> reduceDoubleColumn(List lop,
                    boolean ignoreInvalid, ColumnMetaData metaData) {

        List> res = new ArrayList<>(lop.size());
        for (int i = 0; i < lop.size(); i++) {
            switch (lop.get(i)) {
                case Prod:
                    res.add(new AggregatorImpls.AggregableProd());
                    break;
                case Min:
                    res.add(new AggregatorImpls.AggregableMin());
                    break;
                case Max:
                    res.add(new AggregatorImpls.AggregableMax());
                    break;
                case Range:
                    res.add(new AggregatorImpls.AggregableRange());
                    break;
                case Sum:
                    res.add(new AggregatorImpls.AggregableSum());
                    break;
                case Mean:
                    res.add(new AggregatorImpls.AggregableMean());
                    break;
                case Stdev:
                    res.add(new AggregatorImpls.AggregableStdDev());
                    break;
                case UncorrectedStdDev:
                    res.add(new AggregatorImpls.AggregableUncorrectedStdDev());
                    break;
                case Variance:
                    res.add(new AggregatorImpls.AggregableVariance());
                    break;
                case PopulationVariance:
                    res.add(new AggregatorImpls.AggregablePopulationVariance());
                    break;
                case Count:
                    res.add(new AggregatorImpls.AggregableCount());
                    break;
                case CountUnique:
                    res.add(new AggregatorImpls.AggregableCountUnique());
                    break;
                case TakeFirst:
                    res.add(new AggregatorImpls.AggregableFirst());
                    break;
                case TakeLast:
                    res.add(new AggregatorImpls.AggregableLast());
                    break;
                default:
                    throw new UnsupportedOperationException("Unknown or not implemented op: " + lop.get(i));
            }
        }
        IAggregableReduceOp> thisOp = new DoubleWritableOp<>(new AggregableMultiOp<>(res));
        if (ignoreInvalid)
            return new AggregableCheckingOp<>(thisOp, metaData);
        else
            return thisOp;
    }

    public static IAggregableReduceOp> reduceStringOrCategoricalColumn(List lop,
                    boolean ignoreInvalid, ColumnMetaData metaData) {

        List> res = new ArrayList<>(lop.size());
        for (int i = 0; i < lop.size(); i++) {
            switch (lop.get(i)) {
                case Count:
                    res.add(new AggregatorImpls.AggregableCount());
                    break;
                case CountUnique:
                    res.add(new AggregatorImpls.AggregableCountUnique());
                    break;
                case TakeFirst:
                    res.add(new AggregatorImpls.AggregableFirst());
                    break;
                case TakeLast:
                    res.add(new AggregatorImpls.AggregableLast());
                    break;
                case Append:
                    res.add(new StringAggregatorImpls.AggregableStringAppend());
                    break;
                case Prepend:
                    res.add(new StringAggregatorImpls.AggregableStringPrepend());
                    break;
                default:
                    throw new UnsupportedOperationException("Cannot execute op \"" + lop.get(i)
                                    + "\" on String/Categorical column "
                                    + "(can only perform Append, Prepend, Count, CountUnique, TakeFirst and TakeLast ops on categorical columns)");
            }
        }

        IAggregableReduceOp> thisOp = new StringWritableOp<>(new AggregableMultiOp<>(res));
        if (ignoreInvalid)
            return new AggregableCheckingOp<>(thisOp, metaData);
        else
            return thisOp;
    }

    public static IAggregableReduceOp> reduceTimeColumn(List lop,
                    boolean ignoreInvalid, ColumnMetaData metaData) {

        List> res = new ArrayList<>(lop.size());
        for (int i = 0; i < lop.size(); i++) {
            switch (lop.get(i)) {
                case Min:
                    res.add(new AggregatorImpls.AggregableMin());
                    break;
                case Max:
                    res.add(new AggregatorImpls.AggregableMax());
                    break;
                case Range:
                    res.add(new AggregatorImpls.AggregableRange());
                    break;
                case Mean:
                    res.add(new AggregatorImpls.AggregableMean());
                    break;
                case Stdev:
                    res.add(new AggregatorImpls.AggregableStdDev());
                    break;
                case Count:
                    res.add(new AggregatorImpls.AggregableCount());
                    break;
                case CountUnique:
                    res.add(new AggregatorImpls.AggregableCountUnique());
                    break;
                case TakeFirst:
                    res.add(new AggregatorImpls.AggregableFirst());
                    break;
                case TakeLast:
                    res.add(new AggregatorImpls.AggregableLast());
                    break;
                default:
                    throw new UnsupportedOperationException(
                                    "Reduction op \"" + lop.get(i) + "\" not supported on time columns");
            }
        }
        IAggregableReduceOp> thisOp = new LongWritableOp<>(new AggregableMultiOp<>(res));
        if (ignoreInvalid)
            return new AggregableCheckingOp<>(thisOp, metaData);
        else
            return thisOp;
    }

    public static IAggregableReduceOp> reduceBytesColumn(List lop,
                    boolean ignoreInvalid, ColumnMetaData metaData) {

        List> res = new ArrayList<>(lop.size());
        for (int i = 0; i < lop.size(); i++) {
            switch (lop.get(i)) {
                case TakeFirst:
                    res.add(new AggregatorImpls.AggregableFirst());
                    break;
                case TakeLast:
                    res.add(new AggregatorImpls.AggregableLast());
                    break;
                default:
                    throw new UnsupportedOperationException("Cannot execute op \"" + lop.get(i) + "\" on Bytes column "
                                    + "(can only perform TakeFirst and TakeLast ops on bytes columns)");
            }
        }
        IAggregableReduceOp> thisOp = new ByteWritableOp<>(new AggregableMultiOp<>(res));
        if (ignoreInvalid)
            return new AggregableCheckingOp<>(thisOp, metaData);
        else
            return thisOp;
    }


}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy