/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    KMeansMapTask
 *    Copyright (C) 2014 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.distributed;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Vector;

import weka.classifiers.rules.DecisionTableHashKey;
import weka.clusterers.ClusterUtils;
import weka.clusterers.PreconstructedKMeans;
import weka.core.Attribute;
import weka.core.EuclideanDistance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.NormalizableDistance;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.StreamableFilterHelper;
import weka.core.Utils;
import weka.core.stats.NumericStats;
import weka.core.stats.Stats;
import weka.filters.Filter;
import weka.filters.PreconstructedFilter;
import weka.filters.StreamableFilter;
import weka.filters.unsupervised.attribute.PreconstructedMissingValuesReplacer;
import distributed.core.DistributedJobConfig;

/**
 * Map task for k-means clustering. Uses a "pre-constructed" KMeans clusterer
 * internally to perform the clustering (i.e. assigning training points to
 * clusters). This is constructed with the centroids found in the previous
 * iteration. Maintains (partial) summary stats on each centroid (by re-using
 * the ARFF header summary attributes mechanism). Can use an arbitrary number of
 * Streamable filters for preprocessing the data on the fly.
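 * <p>
 * A rough per-iteration usage sketch (hypothetical driver-side code; the
 * variable names {@code headerWithSummary}, {@code currentCentroids} and
 * {@code chunk} are placeholders, and option configuration is elided):
 * </p>
 *
 * <pre>
 * {@code
 * KMeansMapTask task = new KMeansMapTask();
 * // header of the transformed space - needed by the KMeansReduceTask
 * Instances transformed = task.init(headerWithSummary);
 * task.setCentroids(currentCentroids); // centroids in the transformed space
 * for (int i = 0; i < chunk.numInstances(); i++) {
 *   task.processInstance(chunk.instance(i));
 * }
 * List<Instances> partialStats = task.getCentroidStats();
 * }
 * </pre>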
 * 
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: $
 */
public class KMeansMapTask implements OptionHandler, Serializable {

  /**
   * For serialization
   */
  private static final long serialVersionUID = -2423639970668815722L;

  /** The header of the incoming data, with summary attributes included */
  protected Instances m_headerWithSummary;

  /**
   * The header to use after it has been through any filters that the user has
   * specified
   */
  protected Instances m_transformedHeaderNoSummary;

  /** The list of filters to use */
  protected List<Filter> m_filtersToUse = new ArrayList<Filter>();

  /** The missing values replacer to use */
  protected PreconstructedFilter m_missingValuesReplacer;

  /** True if missing values are not to be replaced */
  protected boolean m_dontReplaceMissing;

  /**
   * The final pre-processing filter to use (encapsulating all specified filters
   * and the missing values replacer)
   */
  protected PreconstructedFilter m_finalFullPreprocess;

  /** The distance function to use (Euclidean distance by default) */
  protected NormalizableDistance m_distanceFunction = new EuclideanDistance();

  /** The current centroids */
  protected Instances m_centroids;

  /** The partial stats for each centroid */
  protected List<Map<String, Stats>> m_centroidSummaryStats;

  /** The KMeans instance to use for clustering the training data */
  protected PreconstructedKMeans m_kMeans;

  /**
   * True if the data goes through more than just the missing values replacement
   */
  protected boolean m_dataIsBeingTransformed;

  /**
   * Whether we should update the distance function (i.e. to update range info)
   * with each incoming instance. If no filters (apart from missing) are
   * involved then we don't need to update (as the dummy priming data computed
   * in init() contains global min/max). If we are using filters, then the first
   * k-means iteration will need to update. However, if using filters, after the
   * first iteration the client should call setDummyDistancePrimingData() with
   * the priming data computed by the reduce task (at this point, all
   * transformed instances have been seen and the partial summary metadata for
   * the clusters can be used to compute global min/max in the transformed
   * space). This priming data should be retained somewhere and used in further
   * iterations if the distributed platform requires instantiation of new
   * map/reduce task objects for each pass over the data. A rough sketch of
   * this sequence is given in the comment following
   * setDummyDistancePrimingData() below.
   */
  protected boolean m_updateDistanceFunction;

  /** Client will set this to true once the convergence criterion has been met */
  protected boolean m_converged;

  /**
   * Initializes the map task. Configures any filters required.
   *
   * @param headerWithSummary header of the incoming instances with summary
   *          attributes included
   * @return the header (without summary attributes) after it has been through
   *         any filters that the user may have specified. This structure is
   *         needed by the KMeansReduceTask
   * @throws DistributedWekaException if a problem occurs
   */
  public Instances init(Instances headerWithSummary)
    throws DistributedWekaException {
    // to be called after setOptions();

    m_headerWithSummary = headerWithSummary;

    m_transformedHeaderNoSummary =
      CSVToARFFHeaderReduceTask.stripSummaryAtts(m_headerWithSummary);
    Instances dummyDistancePrimer =
      ClusterUtils.getPrimingDataForDistanceFunction(m_headerWithSummary);

    // deal with filters
    if (!m_dontReplaceMissing) {
      try {
        m_missingValuesReplacer =
          new PreconstructedMissingValuesReplacer(m_headerWithSummary);
      } catch (Exception ex) {
        throw new DistributedWekaException(ex);
      }
    }

    configureFilters(m_transformedHeaderNoSummary);

    if (!m_dataIsBeingTransformed) {
      m_distanceFunction.setInstances(dummyDistancePrimer);
    } else {
      m_distanceFunction.setInstances(((Filter) m_finalFullPreprocess)
        .getOutputFormat());
      m_updateDistanceFunction = true;
    }

    m_kMeans = new PreconstructedKMeans();
    try {
      m_kMeans.setDistanceFunction(m_distanceFunction);
    } catch (Exception ex) {
      throw new DistributedWekaException(ex);
    }

    // apply any filters to our header (sans summary atts)
    try {
      m_transformedHeaderNoSummary = applyFilters(m_transformedHeaderNoSummary);
    } catch (Exception e) {
      throw new DistributedWekaException(e);
    }

    return m_transformedHeaderNoSummary;
  }

  /**
   * Set the dummy priming data (two-instance dataset that contains global
   * min/max for numeric attributes) for the distance function to use when
   * normalizing numeric attributes. This method should be called when filters
   * that transform the data are being used, and *after* the first iteration of
   * k-means has completed. At this point, the reduce task can compute global
   * min/max for transformed attributes using the partial summary metadata for
   * the clusters computed in the first iteration.
   *
   * @param priming the dummy priming data to use in the distance function
   * @throws DistributedWekaException if init() has not been called first
   */
  public void setDummyDistancePrimingData(Instances priming)
    throws DistributedWekaException {
    if (m_kMeans == null) {
      throw new DistributedWekaException("Must call init() first");
    }

    m_distanceFunction = new EuclideanDistance();
    m_distanceFunction.setInstances(priming);
    try {
      m_kMeans.setDistanceFunction(m_distanceFunction);
      m_updateDistanceFunction = false;
    } catch (Exception ex) {
      throw new DistributedWekaException(ex);
    }
  }
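
  /*
   * Rough sketch (not part of the class) of the distance-priming sequence
   * described above, assuming filters that transform the data are in use.
   * Variable names are placeholders, and how the reduce task exposes the
   * priming data is assumed rather than shown:
   *
   *   // iteration 1: the distance function is updated on the fly
   *   mapTask.init(headerWithSummary);
   *   mapTask.setCentroids(startPoints);
   *   // ... processInstance() over the data, then reduce the partial stats ...
   *
   *   // iterations 2..n: prime the distance function with the global min/max
   *   // computed by the reduce task in the transformed space; retain 'priming'
   *   // for any freshly instantiated map tasks on later passes
   *   mapTask.setDummyDistancePrimingData(priming);
   */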

  /**
   * Processes a training instance. Uses the k-means clusterer to find the
   * nearest centroid to the supplied instance and then updates the summary
   * metadata for the corresponding centroid with the training instance.
   *
   * @param toProcess the instance to process
   * @throws DistributedWekaException if a problem occurs
   */
  public void processInstance(Instance toProcess)
    throws DistributedWekaException {
    if (m_centroids == null) {
      throw new DistributedWekaException("No centroids set!");
    }

    try {
      Instance filteredInstance = applyFilters(toProcess);

      int bestCluster =
        m_kMeans.clusterProcessedInstance(
          /* (Filter) m_finalFullPreprocess */ null, filteredInstance,
          m_updateDistanceFunction, null);

      Map<String, Stats> summaryStats = m_centroidSummaryStats.get(bestCluster);
      for (int i = 0; i < m_transformedHeaderNoSummary.numAttributes(); i++) {
        if (m_transformedHeaderNoSummary.attribute(i).isNominal()
          || m_transformedHeaderNoSummary.attribute(i).isNumeric()) {

          boolean isNominal = filteredInstance.attribute(i).isNominal();
          CSVToARFFHeaderMapTask.updateSummaryStats(summaryStats,
            m_transformedHeaderNoSummary.attribute(i).name(),
            !filteredInstance.isMissing(i) && isNominal ? 1.0
              : filteredInstance.value(i),
            !filteredInstance.isMissing(i) && isNominal
              ? filteredInstance.stringValue(i) : null,
            isNominal, false, false, false, NumericStats.Q_COMPRESSION);
        }
      }
    } catch (Exception ex) {
      throw new DistributedWekaException(ex);
    }
  }

  /**
   * Computes the distance between the two supplied instances
   *
   * @param one the first instance
   * @param two the second instance
   * @return the distance between the two
   * @throws DistributedWekaException if a problem occurs
   */
  public double distance(Instance one, Instance two)
    throws DistributedWekaException {
    if (m_kMeans == null) {
      throw new DistributedWekaException(
        "[KMeansMapTask] We haven't been initialized yet");
    }

    return m_kMeans.getDistanceFunction().distance(one, two);
  }

  /**
   * Get the summary stats for each centroid
   *
   * @return the summary stats (one instances object) for each centroid
   */
  public List<Instances> getCentroidStats() {

    double[] clusterErrors = m_kMeans.getErrorsForClusters();

    List<Instances> centerStats = new ArrayList<Instances>();
    for (int i = 0; i < m_centroids.numInstances(); i++) {
      Map<String, Stats> centroidStats = m_centroidSummaryStats.get(i);
      ArrayList<Attribute> atts = new ArrayList<Attribute>();

      for (int j = 0; j < m_transformedHeaderNoSummary.numAttributes(); j++) {
        atts.add((Attribute) m_transformedHeaderNoSummary.attribute(j).copy());
      }

      boolean ok = true;
      for (int j = 0; j < m_transformedHeaderNoSummary.numAttributes(); j++) {
        if (m_transformedHeaderNoSummary.attribute(j).isNominal()
          || m_transformedHeaderNoSummary.attribute(j).isNumeric()) {
          Stats s =
            centroidStats.get(m_transformedHeaderNoSummary.attribute(j).name());
          // if any stats attribute is null then it means no instances
          // were assigned to this centroid
          if (s == null) {
            System.err.println("No instances for centroid: " + i + " "
              + m_centroids.instance(i));
            ok = false;
            break;
          }
          atts.add(s.makeAttribute());
        }
      }

      if (ok) {
        // we embed the partial error for each cluster in the relation name
        Instances cStats =
          new Instances("Partial stats for centroid " + i + " : "
            + clusterErrors[i], atts, 0);
        centerStats.add(cStats);
      } else {
        centerStats.add(null);
      }
    }

    return centerStats;
  }

  /**
   * Configures filters to use when clustering
   *
   * @param headerNoSummary the header of the training data sans summary
   *          attributes
   * @throws DistributedWekaException if a problem occurs
   */
  protected void configureFilters(Instances headerNoSummary)
    throws DistributedWekaException {
    // setOptions() will have set up the pre-processing filters. Now
    // we just adjust the final set depending on whether missing values
    // are to be replaced as well. We always want the missing values replacer
    // first in the list so that it operates on the original data
    List<StreamableFilter> filters = new ArrayList<StreamableFilter>();
    if (!getDontReplaceMissingValues()) {
      filters.add((StreamableFilter) m_missingValuesReplacer);
    }
    if (m_filtersToUse != null && m_filtersToUse.size() > 0) {
      for (Filter f : m_filtersToUse) {
        if (!(f instanceof StreamableFilter)) {
          throw new DistributedWekaException("Filter " + f.getClass().getName()
            + " is not a StreamableFilter!");
        }

        filters.add((StreamableFilter) f);
      }
    }

    if (filters.size() > 0) {
      try {
        m_finalFullPreprocess =
          StreamableFilterHelper.wrapStreamableFilters(filters);

        if (filters.size() > 1) {
          m_dataIsBeingTransformed = true;
        }
      } catch (Exception e) {
        throw new DistributedWekaException(e);
      }
    }

    if (m_finalFullPreprocess != null) {
      try {
        ((Filter) m_finalFullPreprocess).setInputFormat(headerNoSummary);
      } catch (Exception e) {
        throw new DistributedWekaException(e);
      }
    }
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String dontReplaceMissingValuesTipText() {
    return "Don't replace missing values globally with mean/mode.";
  }

  /**
   * Sets whether missing values are to be replaced.
   *
   * @param r true if missing values are to be replaced
   */
  public void setDontReplaceMissingValues(boolean r) {
    m_dontReplaceMissing = r;
  }

  /**
   * Gets whether missing values are to be replaced.
   * 
   * @return true if missing values are to be replaced
   */
  public boolean getDontReplaceMissingValues() {
    return m_dontReplaceMissing;
  }

  /**
   * Get the distance function in use
   *
   * @return the distance function in use
   */
  public NormalizableDistance getDistanceFunction() {
    return m_distanceFunction;
  }

  /**
   * Set the cluster centroids to use for this iteration. NOTE: These should be
   * in the transformed space if any filters (including missing values
   * replacement) are being used.
   *
   * @param centers the centroids to use
   */
  public void setCentroids(Instances centers) {
    m_centroids = centers;
    m_kMeans.setClusterCentroids(centers);

    m_centroidSummaryStats = new ArrayList<Map<String, Stats>>();
    for (int i = 0; i < centers.numInstances(); i++) {
      m_centroidSummaryStats.add(new HashMap<String, Stats>());
    }
  }

  /**
   * Apply the filters (if any) setup for this map task to the supplied
   * instances
   *
   * @param toApplyTo the instances to filter
   * @return a filtered set of instances
   * @throws Exception if a problem occurs
   */
  public Instances applyFilters(Instances toApplyTo) throws Exception {
    Instances result = toApplyTo;
    if (m_finalFullPreprocess != null) {
      result =
        new Instances(((Filter) m_finalFullPreprocess).getOutputFormat(), 0);
      for (int i = 0; i < toApplyTo.numInstances(); i++) {
        ((Filter) m_finalFullPreprocess).input(toApplyTo.instance(i));
        Instance processed = ((Filter) m_finalFullPreprocess).output();
        result.add(processed);
      }
    }

    return result;
  }

  /**
   * Apply the filters (if any) for this map task to the supplied instance
   *
   * @param original the instance in the original space
   * @return a filtered instance
   * @throws Exception if a problem occurs
   */
  public Instance applyFilters(Instance original) throws Exception {
    Instance result = original;

    if (m_finalFullPreprocess != null) {
      ((Filter) m_finalFullPreprocess).input(result);
      result = ((Filter) m_finalFullPreprocess).output();
    }

    return result;
  }

  /**
   * Gets the full set of preprocessing filters
   *
   * @return preprocessing filter(s) or null if no preprocessing/missing values
   *         handling is being done
   */
  public Filter getPreprocessingFilters() {
    return (Filter) m_finalFullPreprocess;
  }

  /**
   * Set whether the run of k-means that this map is associated with has
   * converged or not
   *
   * @param converged true if the run has converged
   */
  public void setConverged(boolean converged) {
    m_converged = converged;
  }

  /**
   * Get whether the run of k-means that this map task is associated with has
   * converged
   * 
   * @return true if the run has converged
   */
  public boolean getConverged() {
    return m_converged;
  }

  /**
   * Get the header of the data after it has been through any pre-processing
   * filters specified by the user
   *
   * @return the transformed header
   */
  public Instances getTransformedHeader() {
    return m_transformedHeaderNoSummary;
  }

  /**
   * Get the user-specified filters to use with the k-means clusterer. Does not
   * include the missing values replacement filter that is automatically
   * configured using global ARFF profiling summary data
   *
   * @return the user-specified filters to use with k-means
   */
  public Filter[] getFiltersToUse() {
    List<Filter> finalList = new ArrayList<Filter>();
    for (Filter f : m_filtersToUse) {
      if (!(f instanceof PreconstructedFilter)) {
        finalList.add(f);
      }
    }

    return finalList.toArray(new Filter[finalList.size()]);
  }

  /**
   * Set the user-specified filters to use with the k-means clusterer. Does not
   * include the missing values replacement filter that is automatically
   * configured using global ARFF profiling summary data
   *
   * @param toUse the user-specified filters to use with k-means
   */
  public void setFiltersToUse(Filter[] toUse) {
    m_filtersToUse.clear();

    if (toUse != null && toUse.length > 0) {
      for (Filter f : toUse) {
        if (!(f instanceof PreconstructedFilter)
          && f instanceof StreamableFilter) {
          m_filtersToUse.add(f);
        }
      }
    }
  }
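
  /*
   * Example (sketch): supplying user-specified pre-processing filters before
   * init(). Only StreamableFilter implementations (that are not already
   * Preconstructed) are retained by this setter; anything else is ignored.
   * 'someStreamableFilter' and 'headerWithSummary' are placeholders:
   *
   *   KMeansMapTask task = new KMeansMapTask();
   *   task.setFiltersToUse(new Filter[] { someStreamableFilter });
   *   task.init(headerWithSummary);
   */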

  /**
   * The tool tip text for this property.
   *
   * @return the tool tip text for this property
   */
  public String filtersToUseTipText() {
    return "Filters to pre-process the data with before "
      + "passing it to k-means. Note that only StreamableFilters can be used.";
  }

  @Override
  public Enumeration<Option> listOptions() {



