// org.datavec.api.transform.reduce.AggregableReductionUtils Maven / Gradle / Ivy
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.api.transform.reduce;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.ReduceOp;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.ops.*;
import org.datavec.api.writable.Writable;
import java.util.ArrayList;
import java.util.List;
/**
* Various utilities for performing reductions
*
* @author Alex Black
*/
public class AggregableReductionUtils {
private AggregableReductionUtils() {}
public static IAggregableReduceOp> reduceColumn(List op, ColumnType type,
boolean ignoreInvalid, ColumnMetaData metaData) {
switch (type) {
case Integer:
return reduceIntColumn(op, ignoreInvalid, metaData);
case Long:
return reduceLongColumn(op, ignoreInvalid, metaData);
case Float:
return reduceFloatColumn(op, ignoreInvalid, metaData);
case Double:
return reduceDoubleColumn(op, ignoreInvalid, metaData);
case String:
case Categorical:
return reduceStringOrCategoricalColumn(op, ignoreInvalid, metaData);
case Time:
return reduceTimeColumn(op, ignoreInvalid, metaData);
case Bytes:
return reduceBytesColumn(op, ignoreInvalid, metaData);
default:
throw new UnsupportedOperationException("Unknown or not implemented column type: " + type);
}
}
public static IAggregableReduceOp> reduceIntColumn(List lop,
boolean ignoreInvalid, ColumnMetaData metaData) {
List> res = new ArrayList<>(lop.size());
for (int i = 0; i < lop.size(); i++) {
switch (lop.get(i)) {
case Prod:
res.add(new AggregatorImpls.AggregableProd());
break;
case Min:
res.add(new AggregatorImpls.AggregableMin());
break;
case Max:
res.add(new AggregatorImpls.AggregableMax());
break;
case Range:
res.add(new AggregatorImpls.AggregableRange());
break;
case Sum:
res.add(new AggregatorImpls.AggregableSum());
break;
case Mean:
res.add(new AggregatorImpls.AggregableMean());
break;
case Stdev:
res.add(new AggregatorImpls.AggregableStdDev());
break;
case UncorrectedStdDev:
res.add(new AggregatorImpls.AggregableUncorrectedStdDev());
break;
case Variance:
res.add(new AggregatorImpls.AggregableVariance());
break;
case PopulationVariance:
res.add(new AggregatorImpls.AggregablePopulationVariance());
break;
case Count:
res.add(new AggregatorImpls.AggregableCount());
break;
case CountUnique:
res.add(new AggregatorImpls.AggregableCountUnique());
break;
case TakeFirst:
res.add(new AggregatorImpls.AggregableFirst());
break;
case TakeLast:
res.add(new AggregatorImpls.AggregableLast());
break;
default:
throw new UnsupportedOperationException("Unknown or not implemented op: " + lop.get(i));
}
}
IAggregableReduceOp> thisOp = new IntWritableOp<>(new AggregableMultiOp<>(res));
if (ignoreInvalid)
return new AggregableCheckingOp<>(thisOp, metaData);
else
return thisOp;
}
public static IAggregableReduceOp> reduceLongColumn(List lop,
boolean ignoreInvalid, ColumnMetaData metaData) {
List> res = new ArrayList<>(lop.size());
for (int i = 0; i < lop.size(); i++) {
switch (lop.get(i)) {
case Prod:
res.add(new AggregatorImpls.AggregableProd());
break;
case Min:
res.add(new AggregatorImpls.AggregableMin());
break;
case Max:
res.add(new AggregatorImpls.AggregableMax());
break;
case Range:
res.add(new AggregatorImpls.AggregableRange());
break;
case Sum:
res.add(new AggregatorImpls.AggregableSum());
break;
case Stdev:
res.add(new AggregatorImpls.AggregableStdDev());
break;
case UncorrectedStdDev:
res.add(new AggregatorImpls.AggregableUncorrectedStdDev());
break;
case Variance:
res.add(new AggregatorImpls.AggregableVariance());
break;
case PopulationVariance:
res.add(new AggregatorImpls.AggregablePopulationVariance());
break;
case Mean:
res.add(new AggregatorImpls.AggregableMean());
break;
case Count:
res.add(new AggregatorImpls.AggregableCount());
break;
case CountUnique:
res.add(new AggregatorImpls.AggregableCountUnique());
break;
case TakeFirst:
res.add(new AggregatorImpls.AggregableFirst());
break;
case TakeLast:
res.add(new AggregatorImpls.AggregableLast());
break;
default:
throw new UnsupportedOperationException("Unknown or not implemented op: " + lop.get(i));
}
}
IAggregableReduceOp> thisOp = new LongWritableOp<>(new AggregableMultiOp<>(res));
if (ignoreInvalid)
return new AggregableCheckingOp<>(thisOp, metaData);
else
return thisOp;
}
public static IAggregableReduceOp> reduceFloatColumn(List lop,
boolean ignoreInvalid, ColumnMetaData metaData) {
List> res = new ArrayList<>(lop.size());
for (int i = 0; i < lop.size(); i++) {
switch (lop.get(i)) {
case Prod:
res.add(new AggregatorImpls.AggregableProd());
break;
case Min:
res.add(new AggregatorImpls.AggregableMin());
break;
case Max:
res.add(new AggregatorImpls.AggregableMax());
break;
case Range:
res.add(new AggregatorImpls.AggregableRange());
break;
case Sum:
res.add(new AggregatorImpls.AggregableSum());
break;
case Mean:
res.add(new AggregatorImpls.AggregableMean());
break;
case Stdev:
res.add(new AggregatorImpls.AggregableStdDev());
break;
case UncorrectedStdDev:
res.add(new AggregatorImpls.AggregableUncorrectedStdDev());
break;
case Variance:
res.add(new AggregatorImpls.AggregableVariance());
break;
case PopulationVariance:
res.add(new AggregatorImpls.AggregablePopulationVariance());
break;
case Count:
res.add(new AggregatorImpls.AggregableCount());
break;
case CountUnique:
res.add(new AggregatorImpls.AggregableCountUnique());
break;
case TakeFirst:
res.add(new AggregatorImpls.AggregableFirst());
break;
case TakeLast:
res.add(new AggregatorImpls.AggregableLast());
break;
default:
throw new UnsupportedOperationException("Unknown or not implemented op: " + lop.get(i));
}
}
IAggregableReduceOp> thisOp = new FloatWritableOp<>(new AggregableMultiOp<>(res));
if (ignoreInvalid)
return new AggregableCheckingOp<>(thisOp, metaData);
else
return thisOp;
}
public static IAggregableReduceOp> reduceDoubleColumn(List lop,
boolean ignoreInvalid, ColumnMetaData metaData) {
List> res = new ArrayList<>(lop.size());
for (int i = 0; i < lop.size(); i++) {
switch (lop.get(i)) {
case Prod:
res.add(new AggregatorImpls.AggregableProd());
break;
case Min:
res.add(new AggregatorImpls.AggregableMin());
break;
case Max:
res.add(new AggregatorImpls.AggregableMax());
break;
case Range:
res.add(new AggregatorImpls.AggregableRange());
break;
case Sum:
res.add(new AggregatorImpls.AggregableSum());
break;
case Mean:
res.add(new AggregatorImpls.AggregableMean());
break;
case Stdev:
res.add(new AggregatorImpls.AggregableStdDev());
break;
case UncorrectedStdDev:
res.add(new AggregatorImpls.AggregableUncorrectedStdDev());
break;
case Variance:
res.add(new AggregatorImpls.AggregableVariance());
break;
case PopulationVariance:
res.add(new AggregatorImpls.AggregablePopulationVariance());
break;
case Count:
res.add(new AggregatorImpls.AggregableCount());
break;
case CountUnique:
res.add(new AggregatorImpls.AggregableCountUnique());
break;
case TakeFirst:
res.add(new AggregatorImpls.AggregableFirst());
break;
case TakeLast:
res.add(new AggregatorImpls.AggregableLast());
break;
default:
throw new UnsupportedOperationException("Unknown or not implemented op: " + lop.get(i));
}
}
IAggregableReduceOp> thisOp = new DoubleWritableOp<>(new AggregableMultiOp<>(res));
if (ignoreInvalid)
return new AggregableCheckingOp<>(thisOp, metaData);
else
return thisOp;
}
public static IAggregableReduceOp> reduceStringOrCategoricalColumn(List lop,
boolean ignoreInvalid, ColumnMetaData metaData) {
List> res = new ArrayList<>(lop.size());
for (int i = 0; i < lop.size(); i++) {
switch (lop.get(i)) {
case Count:
res.add(new AggregatorImpls.AggregableCount());
break;
case CountUnique:
res.add(new AggregatorImpls.AggregableCountUnique());
break;
case TakeFirst:
res.add(new AggregatorImpls.AggregableFirst());
break;
case TakeLast:
res.add(new AggregatorImpls.AggregableLast());
break;
case Append:
res.add(new StringAggregatorImpls.AggregableStringAppend());
break;
case Prepend:
res.add(new StringAggregatorImpls.AggregableStringPrepend());
break;
default:
throw new UnsupportedOperationException("Cannot execute op \"" + lop.get(i)
+ "\" on String/Categorical column "
+ "(can only perform Append, Prepend, Count, CountUnique, TakeFirst and TakeLast ops on categorical columns)");
}
}
IAggregableReduceOp> thisOp = new StringWritableOp<>(new AggregableMultiOp<>(res));
if (ignoreInvalid)
return new AggregableCheckingOp<>(thisOp, metaData);
else
return thisOp;
}
public static IAggregableReduceOp> reduceTimeColumn(List lop,
boolean ignoreInvalid, ColumnMetaData metaData) {
List> res = new ArrayList<>(lop.size());
for (int i = 0; i < lop.size(); i++) {
switch (lop.get(i)) {
case Min:
res.add(new AggregatorImpls.AggregableMin());
break;
case Max:
res.add(new AggregatorImpls.AggregableMax());
break;
case Range:
res.add(new AggregatorImpls.AggregableRange());
break;
case Mean:
res.add(new AggregatorImpls.AggregableMean());
break;
case Stdev:
res.add(new AggregatorImpls.AggregableStdDev());
break;
case Count:
res.add(new AggregatorImpls.AggregableCount());
break;
case CountUnique:
res.add(new AggregatorImpls.AggregableCountUnique());
break;
case TakeFirst:
res.add(new AggregatorImpls.AggregableFirst());
break;
case TakeLast:
res.add(new AggregatorImpls.AggregableLast());
break;
default:
throw new UnsupportedOperationException(
"Reduction op \"" + lop.get(i) + "\" not supported on time columns");
}
}
IAggregableReduceOp> thisOp = new LongWritableOp<>(new AggregableMultiOp<>(res));
if (ignoreInvalid)
return new AggregableCheckingOp<>(thisOp, metaData);
else
return thisOp;
}
public static IAggregableReduceOp> reduceBytesColumn(List lop,
boolean ignoreInvalid, ColumnMetaData metaData) {
List> res = new ArrayList<>(lop.size());
for (int i = 0; i < lop.size(); i++) {
switch (lop.get(i)) {
case TakeFirst:
res.add(new AggregatorImpls.AggregableFirst());
break;
case TakeLast:
res.add(new AggregatorImpls.AggregableLast());
break;
default:
throw new UnsupportedOperationException("Cannot execute op \"" + lop.get(i) + "\" on Bytes column "
+ "(can only perform TakeFirst and TakeLast ops on bytes columns)");
}
}
IAggregableReduceOp> thisOp = new ByteWritableOp<>(new AggregableMultiOp<>(res));
if (ignoreInvalid)
return new AggregableCheckingOp<>(thisOp, metaData);
else
return thisOp;
}
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy