org.apache.mahout.classifier.df.DecisionForest Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mahout-mr Show documentation
Show all versions of mahout-mr Show documentation
Scalable machine learning libraries
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.df;
import com.google.common.base.Preconditions;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataUtils;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.classifier.df.node.Node;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
 * Represents a forest of decision trees.
 *
 * <p>A forest is an ordered list of {@link Node} tree roots. It can be serialized through the
 * Hadoop {@link Writable} contract ({@link #write(DataOutput)} / {@link #readFields(DataInput)})
 * and loaded from one file or a directory of files with {@link #load(Configuration, Path)}.
 *
 * <p>Equality is order-independent: two forests are equal when they hold the same trees with the
 * same multiplicities, regardless of position. {@link #hashCode()} is consistent with that.
 */
@Deprecated
public class DecisionForest implements Writable {

  /** Tree roots of this forest; appended to by {@link #readFields(DataInput)}. */
  private final List<Node> trees;

  /** Creates an empty forest, to be populated through {@link #readFields(DataInput)}. */
  private DecisionForest() {
    trees = new ArrayList<>();
  }

  /**
   * Creates a forest from the given trees.
   *
   * @param trees
   *          tree roots; must not be null or empty
   * @throws IllegalArgumentException
   *           if {@code trees} is null or empty
   */
  public DecisionForest(List<Node> trees) {
    Preconditions.checkArgument(trees != null && !trees.isEmpty(), "trees argument must not be null or empty");
    // Defensive copy: readFields() appends to this list, which must neither mutate the caller's
    // list nor fail when the caller passed an unmodifiable one.
    this.trees = new ArrayList<>(trees);
  }

  /**
   * @return the live internal list of trees (not a copy)
   */
  List<Node> getTrees() {
    return trees;
  }

  /**
   * Classifies every instance of the data with every tree and stores the raw per-tree results.
   *
   * <p>After the call, {@code predictions[i][t]} holds the value returned by tree {@code t} for
   * instance {@code i}. Rows of {@code predictions} that are null are allocated here with one slot
   * per tree; rows that are already allocated are written in place.
   *
   * @param data
   *          instances to classify
   * @param predictions
   *          output array, one row per instance
   * @throws IllegalArgumentException
   *           if {@code predictions.length} differs from {@code data.size()}
   */
  public void classify(Data data, double[][] predictions) {
    Preconditions.checkArgument(data.size() == predictions.length, "predictions.length must be equal to data.size()");

    if (data.isEmpty()) {
      return; // nothing to classify
    }

    int treeId = 0;
    for (Node tree : trees) {
      for (int index = 0; index < data.size(); index++) {
        if (predictions[index] == null) {
          predictions[index] = new double[trees.size()];
        }
        predictions[index][treeId] = tree.classify(data.get(index));
      }
      treeId++;
    }
  }

  /**
   * Predicts the label for the instance.
   *
   * <p>When the label attribute is numerical, the result is the mean of all non-NaN tree
   * predictions. Otherwise each non-NaN tree prediction counts as one vote for the label code it
   * returned, and the most voted label is selected through {@code DataUtils.maxindex}.
   *
   * @param dataset
   *          describes the label attribute (its type and number of labels)
   * @param rng
   *          Random number generator, used to break ties randomly
   * @param instance
   *          instance to classify
   * @return NaN if the label cannot be predicted
   */
  public double classify(Dataset dataset, Random rng, Instance instance) {
    if (dataset.isNumerical(dataset.getLabelId())) {
      // Regression: average the predictions of the trees that could predict.
      double sum = 0;
      int cnt = 0;
      for (Node tree : trees) {
        double prediction = tree.classify(instance);
        if (!Double.isNaN(prediction)) {
          sum += prediction;
          cnt++;
        }
      }

      if (cnt > 0) {
        return sum / cnt;
      } else {
        return Double.NaN;
      }
    } else {
      // Classification: one vote per tree, indexed by the predicted label code.
      int[] predictions = new int[dataset.nblabels()];
      for (Node tree : trees) {
        double prediction = tree.classify(instance);
        if (!Double.isNaN(prediction)) {
          predictions[(int) prediction]++;
        }
      }

      if (DataUtils.sum(predictions) == 0) {
        return Double.NaN; // no prediction available
      }

      return DataUtils.maxindex(rng, predictions);
    }
  }

  /**
   * @return Mean number of nodes per tree (integer division), or 0 for an empty forest
   */
  public long meanNbNodes() {
    if (trees.isEmpty()) {
      return 0; // avoid dividing by zero on a forest read from an empty stream
    }

    long sum = 0;
    for (Node tree : trees) {
      sum += tree.nbNodes();
    }

    return sum / trees.size();
  }

  /**
   * @return Total number of nodes in all the trees
   */
  public long nbNodes() {
    long sum = 0;
    for (Node tree : trees) {
      sum += tree.nbNodes();
    }

    return sum;
  }

  /**
   * @return Mean maximum depth per tree (integer division), or 0 for an empty forest
   */
  public long meanMaxDepth() {
    if (trees.isEmpty()) {
      return 0; // avoid dividing by zero on a forest read from an empty stream
    }

    long sum = 0;
    for (Node tree : trees) {
      sum += tree.maxDepth();
    }

    return sum / trees.size();
  }

  /**
   * Two forests are equal when they contain the same trees with the same multiplicities, in any
   * order.
   */
  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }
    if (!(obj instanceof DecisionForest)) {
      return false;
    }

    List<Node> otherTrees = ((DecisionForest) obj).getTrees();
    if (trees.size() != otherTrees.size()) {
      return false;
    }

    // Multiset comparison: each of our trees must consume one distinct equal tree of the other
    // forest. A size check plus containsAll() is not symmetric when duplicates are present.
    List<Node> unmatched = new ArrayList<>(otherTrees);
    for (Node tree : trees) {
      if (!unmatched.remove(tree)) {
        return false;
      }
    }

    return true;
  }

  /**
   * Order-independent hash (sum of the tree hashes) so that forests that are equal under
   * {@link #equals(Object)} always share the same hash code.
   */
  @Override
  public int hashCode() {
    int hash = 0;
    for (Node tree : trees) {
      hash += tree == null ? 0 : tree.hashCode();
    }

    return hash;
  }

  /**
   * Writes the number of trees followed by each tree, in list order.
   *
   * @throws IOException
   *           if writing to {@code dataOutput} fails
   */
  @Override
  public void write(DataOutput dataOutput) throws IOException {
    dataOutput.writeInt(trees.size());
    for (Node tree : trees) {
      tree.write(dataOutput);
    }
  }

  /**
   * Reads the trees from the input and adds them to the existing trees.
   *
   * <p>The input is expected to start with the number of trees, as written by
   * {@link #write(DataOutput)}. Existing trees are kept, so calling this repeatedly merges several
   * serialized forests into one.
   *
   * @throws IOException
   *           if reading from {@code dataInput} fails
   */
  @Override
  public void readFields(DataInput dataInput) throws IOException {
    int size = dataInput.readInt();
    for (int i = 0; i < size; i++) {
      trees.add(Node.read(dataInput));
    }
  }

  /**
   * Read the forest from inputStream
   * @param dataInput - input forest
   * @return {@link org.apache.mahout.classifier.df.DecisionForest}
   * @throws IOException if reading from {@code dataInput} fails
   */
  public static DecisionForest read(DataInput dataInput) throws IOException {
    DecisionForest forest = new DecisionForest();
    forest.readFields(dataInput);
    return forest;
  }

  /**
   * Load the forest from a single file or a directory of files.
   *
   * <p>When {@code forestPath} is a directory, the files returned by
   * {@code DFUtils.listOutputFiles} are read in order and their trees are merged into one forest.
   *
   * @param conf
   *          configuration used to resolve the file system of {@code forestPath}
   * @param forestPath
   *          file or directory holding the serialized forest
   * @return the loaded forest, or {@code null} when there is no file to read
   * @throws java.io.IOException
   *           if the path cannot be inspected or a file cannot be read
   */
  public static DecisionForest load(Configuration conf, Path forestPath) throws IOException {
    FileSystem fs = forestPath.getFileSystem(conf);
    Path[] files;
    if (fs.getFileStatus(forestPath).isDir()) {
      files = DFUtils.listOutputFiles(fs, forestPath);
    } else {
      files = new Path[]{forestPath};
    }

    DecisionForest forest = null;
    for (Path path : files) {
      // fs.open() already returns an FSDataInputStream; no extra wrapping is needed.
      try (FSDataInputStream dataInput = fs.open(path)) {
        if (forest == null) {
          forest = read(dataInput);
        } else {
          forest.readFields(dataInput);
        }
      }
    }

    return forest;
  }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy