All downloads are free. Search and download functionality uses the official Maven repository.

com.davidbracewell.apollo.ml.clustering.SilhouetteEvaluation Maven / Gradle / Ivy

The newest version!
/*
 * (c) 2005 David B. Bracewell
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.davidbracewell.apollo.ml.clustering;

import com.davidbracewell.Math2;
import com.davidbracewell.apollo.linear.NDArray;
import com.davidbracewell.apollo.ml.Evaluation;
import com.davidbracewell.apollo.ml.Instance;
import com.davidbracewell.apollo.ml.data.Dataset;
import com.davidbracewell.apollo.stat.measure.Measure;
import com.davidbracewell.stream.StreamingContext;
import com.davidbracewell.string.TableFormatter;
import lombok.NonNull;

import java.io.PrintStream;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

import static com.davidbracewell.tuple.Tuples.$;

/**
 * Evaluates a {@link Clustering} using the silhouette coefficient.
 * <p>
 * For a point {@code i} in cluster {@code C}:
 * <ul>
 * <li>{@code a(i)} is the mean distance from {@code i} to the other members of {@code C};</li>
 * <li>{@code b(i)} is the smallest mean distance from {@code i} to the members of any other non-empty cluster;</li>
 * <li>the point's silhouette is {@code (b(i) - a(i)) / max(a(i), b(i))}.</li>
 * </ul>
 * A cluster's score is the mean silhouette of its points, and {@link #getAvgSilhouette()} is the unweighted mean
 * of the per-cluster scores (every cluster counts equally, regardless of its size).
 * </p>
 * <p>
 * NOTE(review): the measure returned by {@code Clustering#getMeasure()} is assumed to be a distance (smaller means
 * closer and {@code d(x, x) == 0}); confirm this for every measure used with this evaluation.
 * </p>
 *
 * @author David B. Bracewell
 */
public class SilhouetteEvaluation implements Evaluation {
   /** Unweighted mean of the per-cluster silhouette scores from the last evaluation; 0 before any evaluation. */
   double avgSilhouette = 0;
   /** Silhouette score per cluster, keyed by cluster id; empty until an evaluation has run. */
   Map<Integer, Double> silhouette = new HashMap<>();

   @Override
   public void evaluate(@NonNull Clustering model, Dataset dataset) {
      // The silhouette depends only on the clustering itself, so the dataset is not consulted.
      evaluate(model);
   }

   @Override
   public void evaluate(@NonNull Clustering model, Collection dataset) {
      // The silhouette depends only on the clustering itself, so the dataset is not consulted.
      evaluate(model);
   }

   /**
    * Computes the silhouette score of every cluster in the model, replacing any previous results, and updates the
    * average score.
    *
    * @param model the clustering to evaluate
    */
   public void evaluate(@NonNull Clustering model) {
      Map<Integer, Cluster> idClusterMap = new HashMap<>();
      model.forEach(c -> idClusterMap.put(c.getId(), c));
      silhouette = StreamingContext.local().stream(idClusterMap.keySet())
                                   .parallel()
                                   .mapToPair(i -> $(i, silhouette(idClusterMap, i, model.getMeasure())))
                                   .collectAsMap();
      avgSilhouette = Math2.summaryStatistics(silhouette.values()).getAverage();
   }

   /**
    * @return the unweighted mean of the per-cluster silhouette scores from the last evaluation
    */
   public double getAvgSilhouette() {
      return avgSilhouette;
   }

   /**
    * Misspelled legacy accessor kept for backward compatibility.
    *
    * @param id the cluster id
    * @return the silhouette score of the cluster
    * @deprecated use {@link #getSilhouette(int)}
    */
   @Deprecated
   public double getSilhoette(int id) {
      return getSilhouette(id);
   }

   /**
    * Returns the silhouette score computed for the given cluster by the last evaluation.
    *
    * @param id the cluster id
    * @return the silhouette score of the cluster
    * @throws NullPointerException if no score was computed for {@code id}
    */
   public double getSilhouette(int id) {
      Double score = silhouette.get(id);
      if (score == null) {
         // Same exception type the former unboxing of a missing value produced, now with a useful message.
         throw new NullPointerException("No silhouette score for cluster " + id);
      }
      return score;
   }

   @Override
   public void merge(@NonNull Evaluation evaluation) {
      throw new UnsupportedOperationException();
   }

   /**
    * Prints a table of the per-cluster scores (sorted by cluster id) followed by the average score.
    *
    * @param printStream the stream to print to
    */
   @Override
   public void output(@NonNull PrintStream printStream) {
      TableFormatter formatter = new TableFormatter();
      formatter.title("Silhouette Cluster Evaluation");
      formatter.header(Arrays.asList("Cluster", "Silhouette Score"));
      silhouette.keySet()
                .stream()
                .sorted()
                .forEach(id -> formatter.content(Arrays.asList(id, silhouette.get(id))));
      formatter.footer(Arrays.asList("Avg. Score", avgSilhouette));
      formatter.print(printStream);
   }

   /**
    * Discards the results of any previous evaluation.
    */
   public void reset() {
      this.avgSilhouette = 0;
      // Reassign rather than clear(): the map produced by collectAsMap() is not guaranteed to be mutable, and this
      // also works when no evaluation has run yet.
      this.silhouette = new HashMap<>();
   }

   /**
    * Computes the silhouette score of a single cluster.
    * <p>
    * The score is the mean, over the cluster's points, of {@code (b - a) / max(a, b)}. Here {@code a} is the
    * point's mean distance to the other members of its own cluster, and {@code b} is its smallest mean distance to
    * the members of any other non-empty cluster. A point with {@code a == b == 0} contributes 0.
    * </p>
    *
    * @param clusters        the clusters keyed by id
    * @param index           the id of the cluster to score
    * @param distanceMeasure the measure used to compare points (assumed to be a distance with {@code d(x, x) == 0})
    * @return the cluster's silhouette score, or 0 when the cluster has at most one point or when there is no other
    * non-empty cluster to compare against
    */
   public double silhouette(Map<Integer, Cluster> clusters, int index, Measure distanceMeasure) {
      Cluster c1 = clusters.get(index);
      if (c1.size() <= 1) {
         return 0;
      }

      // Without a second non-empty cluster, b(i) is undefined; report 0 like the singleton case above.
      boolean hasOtherCluster = clusters.entrySet()
                                        .stream()
                                        .anyMatch(e -> e.getKey() != index && e.getValue().size() > 0);
      if (!hasOtherCluster) {
         return 0;
      }

      double s = 0;
      for (NDArray point1 : c1) {
         // The sum includes d(point1, point1) == 0, so dividing by size - 1 gives the mean distance to the
         // *other* members of the cluster.
         double ai = 0;
         for (NDArray point2 : c1) {
            ai += distanceMeasure.calculate(point1, point2);
         }
         ai /= (c1.size() - 1);

         // b(i): smallest *mean* distance to another non-empty cluster (empty clusters would divide by zero).
         double bi = Double.POSITIVE_INFINITY;
         for (Map.Entry<Integer, Cluster> entry : clusters.entrySet()) {
            Cluster other = entry.getValue();
            if (entry.getKey() == index || other.size() == 0) {
               continue;
            }
            double b = 0;
            for (NDArray point2 : other) {
               b += distanceMeasure.calculate(point1, point2);
            }
            bi = Math.min(bi, b / other.size());
         }

         double denominator = Math.max(ai, bi);
         if (denominator > 0) {
            s += (bi - ai) / denominator;
         }
         // else: a == b == 0 (all points coincide); the point contributes 0 instead of NaN.
      }

      return s / c1.size();
   }

}//END OF SilhouetteEvaluation




© 2015 - 2025 Weber Informatics LLC | Privacy Policy