All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.clustering.classify.ClusterClassifier Maven / Gradle / Ivy

/* Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.mahout.clustering.classify;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.iterator.ClusterWritable;
import org.apache.mahout.clustering.iterator.ClusteringPolicy;
import org.apache.mahout.clustering.iterator.ClusteringPolicyWritable;
import org.apache.mahout.common.ClassUtils;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/**
 * This classifier works with any ClusteringPolicy and its associated Clusters.
 * It is initialized with a policy and a list of compatible clusters and
 * thereafter it can classify any new Vector into one or more of the clusters
 * based upon the pdf() function which each cluster supports.
 * 

* In addition, it is an OnlineLearner and can be trained. Training amounts to * asking the actual model to observe the vector and closing the classifier * causes all the models to computeParameters. *

* Because a ClusterClassifier implements Writable, it can be written-to and * read-from a sequence file as a single entity. For sequential and MapReduce * clustering in conjunction with a ClusterIterator; however, it utilizes an * exploded file format. In this format, the iterator writes the policy to a * single POLICY_FILE_NAME file in the clustersOut directory and the models are * written to one or more part-n files so that multiple reducers may employed to * produce them. */ public class ClusterClassifier extends AbstractVectorClassifier implements OnlineLearner, Writable { private static final String POLICY_FILE_NAME = "_policy"; private List models; private String modelClass; private ClusteringPolicy policy; /** * The public constructor accepts a list of clusters to become the models * * @param models a List * @param policy a ClusteringPolicy */ public ClusterClassifier(List models, ClusteringPolicy policy) { this.models = models; modelClass = models.get(0).getClass().getName(); this.policy = policy; } // needed for serialization/De-serialization public ClusterClassifier() { } // only used by MR ClusterIterator protected ClusterClassifier(ClusteringPolicy policy) { this.policy = policy; } @Override public Vector classify(Vector instance) { return policy.classify(instance, this); } @Override public double classifyScalar(Vector instance) { if (models.size() == 2) { double pdf0 = models.get(0).pdf(new VectorWritable(instance)); double pdf1 = models.get(1).pdf(new VectorWritable(instance)); return pdf0 / (pdf0 + pdf1); } throw new IllegalStateException(); } @Override public int numCategories() { return models.size(); } @Override public void write(DataOutput out) throws IOException { out.writeInt(models.size()); out.writeUTF(modelClass); new ClusteringPolicyWritable(policy).write(out); for (Cluster cluster : models) { cluster.write(out); } } @Override public void readFields(DataInput in) throws IOException { int size = in.readInt(); modelClass = in.readUTF(); models = new ArrayList<>(); ClusteringPolicyWritable clusteringPolicyWritable = new ClusteringPolicyWritable(); clusteringPolicyWritable.readFields(in); policy = clusteringPolicyWritable.getValue(); for (int i = 0; i < size; i++) { Cluster element = ClassUtils.instantiateAs(modelClass, Cluster.class); element.readFields(in); models.add(element); } } @Override public void train(int actual, Vector instance) { models.get(actual).observe(new VectorWritable(instance)); } /** * Train the models given an additional weight. Unique to ClusterClassifier * * @param actual the int index of a model * @param data a data Vector * @param weight a double weighting factor */ public void train(int actual, Vector data, double weight) { models.get(actual).observe(new VectorWritable(data), weight); } @Override public void train(long trackingKey, String groupKey, int actual, Vector instance) { models.get(actual).observe(new VectorWritable(instance)); } @Override public void train(long trackingKey, int actual, Vector instance) { models.get(actual).observe(new VectorWritable(instance)); } @Override public void close() { policy.close(this); } public List getModels() { return models; } public ClusteringPolicy getPolicy() { return policy; } public void writeToSeqFiles(Path path) throws IOException { writePolicy(policy, path); Configuration config = new Configuration(); FileSystem fs = FileSystem.get(path.toUri(), config); ClusterWritable cw = new ClusterWritable(); for (int i = 0; i < models.size(); i++) { try (SequenceFile.Writer writer = new SequenceFile.Writer(fs, config, new Path(path, "part-" + String.format(Locale.ENGLISH, "%05d", i)), IntWritable.class, ClusterWritable.class)) { Cluster cluster = models.get(i); cw.setValue(cluster); Writable key = new IntWritable(i); writer.append(key, cw); } } } public void readFromSeqFiles(Configuration conf, Path path) throws IOException { Configuration config = new Configuration(); List clusters = new ArrayList<>(); for (ClusterWritable cw : new SequenceFileDirValueIterable(path, PathType.LIST, PathFilters.logsCRCFilter(), config)) { Cluster cluster = cw.getValue(); cluster.configure(conf); clusters.add(cluster); } this.models = clusters; modelClass = models.get(0).getClass().getName(); this.policy = readPolicy(path); } public static ClusteringPolicy readPolicy(Path path) throws IOException { Path policyPath = new Path(path, POLICY_FILE_NAME); Configuration config = new Configuration(); FileSystem fs = FileSystem.get(policyPath.toUri(), config); SequenceFile.Reader reader = new SequenceFile.Reader(fs, policyPath, config); Text key = new Text(); ClusteringPolicyWritable cpw = new ClusteringPolicyWritable(); reader.next(key, cpw); Closeables.close(reader, true); return cpw.getValue(); } public static void writePolicy(ClusteringPolicy policy, Path path) throws IOException { Path policyPath = new Path(path, POLICY_FILE_NAME); Configuration config = new Configuration(); FileSystem fs = FileSystem.get(policyPath.toUri(), config); SequenceFile.Writer writer = new SequenceFile.Writer(fs, config, policyPath, Text.class, ClusteringPolicyWritable.class); writer.append(new Text(), new ClusteringPolicyWritable(policy)); Closeables.close(writer, false); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy