All downloads are free. The search and download features use the official Maven repository.

com.clust4j.algo.BaseNeighborsModel Maven / Gradle / Ivy

Go to download

A suite of classification clustering algorithm implementations for Java. A number of partitional, hierarchical and density-based algorithms including DBSCAN, k-Means, k-Medoids, MeanShift, Affinity Propagation, HDBSCAN and more.

The newest version!
/*******************************************************************************
 *    Copyright 2015, 2016 Taylor G Smith
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 *******************************************************************************/
package com.clust4j.algo;

import org.apache.commons.math3.linear.RealMatrix;

import com.clust4j.GlobalState;
import com.clust4j.algo.Neighborhood;
import com.clust4j.except.ModelNotFitException;
import com.clust4j.metrics.pairwise.DistanceMetric;
import com.clust4j.metrics.pairwise.GeometricallySeparable;

abstract public class BaseNeighborsModel extends AbstractClusterer {
	private static final long serialVersionUID = 1054047329248586585L;
	
	/** Algorithm a planner starts with unless another is set (initial value of {@code BaseNeighborsPlanner.algo}). */
	public static final NeighborsAlgorithm DEF_ALGO = NeighborsAlgorithm.AUTO;
	/** Leaf size a planner starts with (initial value of {@code BaseNeighborsPlanner.leafSize}). */
	public static final int DEF_LEAF_SIZE = 30;
	/** Presumably the default neighbor count for k-based subclasses; not used in the visible code -- confirm. */
	public static final int DEF_K = 5;
	/** Presumably the default search radius for radius-based subclasses; not used in the visible code -- confirm. */
	public static final double DEF_RADIUS = 5.0;
	/** Not used in the visible code; presumably a query option handed to the tree by subclasses -- confirm. */
	public final static boolean DUAL_TREE_SEARCH = false;
	/** Not used in the visible code; presumably a query option handed to the tree by subclasses -- confirm. */
	public final static boolean SORT = true;

	/** Value of the planner's {@code getK()}, captured during construction; may be null. */
	protected Integer kNeighbors = null;
	/** Value of the planner's {@code getRadius()}, captured during construction; may be null. */
	protected Double radius = null;
	/** True exactly when {@link #radius} was non-null at construction time. */
	protected boolean radiusMode;
	/** {@code leafSize}: leaf size passed to the tree builder (must be >= 1); {@code m}: length of {@link #fit_X}. */
	protected int leafSize, m;
	/** The data reference obtained from the built tree via {@code tree.getData()}. */
	protected double[][] fit_X;
	/** The search tree built during construction by the planner's algorithm. */
	protected NearestNeighborHeapSearch tree;
	/** The concrete algorithm that built {@link #tree}; assigned by KD_TREE or BALL_TREE when they build. */
	protected NeighborsAlgorithm alg;
	
	/** Resultant neighborhood from fit method; null until then, in which case {@link #getNeighbors()} throws. */
	protected volatile Neighborhood res;

	/**
	 * Contract for anything that can construct the search tree backing a
	 * neighbors model, while also acting as a {@link MetricValidator}.
	 */
	interface TreeBuilder extends MetricValidator {
		/**
		 * Builds a search tree over the given data.
		 *
		 * @param data the matrix the tree is built from
		 * @param leafSize the leaf size handed to the tree
		 * @param logger the model the tree is built for; the enum implementations
		 *        in this file also use it to emit warnings and record state
		 * @return the constructed tree
		 */
		NearestNeighborHeapSearch buildTree(
				RealMatrix data,
				int leafSize,
				BaseNeighborsModel logger);
	}
	
	/**
	 * The tree implementations a neighbors model can be backed by. Each
	 * constant knows how to build its tree and which metrics it accepts.
	 */
	public static enum NeighborsAlgorithm implements TreeBuilder {
		/**
		 * Picks {@link #BALL_TREE} or {@link #KD_TREE} from the size of the
		 * data and delegates the build to it.
		 */
		AUTO {

			@Override
			public NearestNeighborHeapSearch buildTree(RealMatrix data,
					int leafSize, BaseNeighborsModel logger) {
				
				// The delegate records itself in logger.alg, so AUTO itself is never stored there
				NeighborsAlgorithm alg = delegateAlgorithm(data);
				return alg.buildTree(data, leafSize, logger);
			}
			
			/**
			 * @throws UnsupportedOperationException always; AUTO defines no metric criteria
			 */
			@Override
			public boolean isValidMetric(GeometricallySeparable geo) {
				throw new UnsupportedOperationException("auto has no metric validity criteria");
			}
			
		}, 
		
		/** Builds a {@code KDTree}; valid metrics are those in {@code KDTree.VALID_METRICS}. */
		KD_TREE {

			@Override
			public NearestNeighborHeapSearch buildTree(RealMatrix data,
					int leafSize, BaseNeighborsModel logger) {
				// Record the concrete algorithm first: handleMetric validates against it
				logger.alg = this;
				return new KDTree(data, leafSize, handleMetric(this, logger), logger);
			}
			
			@Override
			public boolean isValidMetric(GeometricallySeparable g) {
				return KDTree.VALID_METRICS.contains(g.getClass());
			}
		}, 
		
		/** Builds a {@code BallTree}; valid metrics are those in {@code BallTree.VALID_METRICS}. */
		BALL_TREE {

			@Override
			public NearestNeighborHeapSearch buildTree(RealMatrix data,
					int leafSize, BaseNeighborsModel logger) {
				// Record the concrete algorithm first: handleMetric validates against it
				logger.alg = this;
				return new BallTree(data, leafSize, handleMetric(this, logger), logger);
			}
			
			@Override
			public boolean isValidMetric(GeometricallySeparable g) {
				return BallTree.VALID_METRICS.contains(g.getClass());
			}
		};
		
		/**
		 * Chooses the concrete algorithm for {@link #AUTO}: {@link #BALL_TREE}
		 * when the matrix holds more than
		 * {@code GlobalState.ParallelismConf.MIN_ELEMENTS} elements, otherwise
		 * {@link #KD_TREE}.
		 *
		 * @param arm the data matrix
		 * @return the algorithm to delegate to
		 */
		private static NeighborsAlgorithm delegateAlgorithm(RealMatrix arm) {
			// Widen before multiplying: rows * cols in int arithmetic can wrap
			// negative for very large matrices and wrongly select KD_TREE.
			final long mn = (long) arm.getColumnDimension() * arm.getRowDimension();
			return mn > GlobalState.ParallelismConf.MIN_ELEMENTS ?
				BALL_TREE : KD_TREE;
		}
		
		/**
		 * Ensures the model's metric is acceptable to the given algorithm. If
		 * it is not, a warning is logged on the model and the model is reset to
		 * {@code DEF_DIST}.
		 *
		 * @param na the algorithm whose validity criteria apply
		 * @param logger the model whose metric is checked and possibly replaced
		 * @return the model's (possibly replaced) metric, cast to {@link DistanceMetric}
		 */
		private static DistanceMetric handleMetric(NeighborsAlgorithm na, BaseNeighborsModel logger) {
			GeometricallySeparable g = logger.dist_metric;
			if(!na.isValidMetric(g)) {
				logger.warn(g.getName()+" is not a valid metric for " + na + ". "
					+ "Falling back to default Euclidean");
				logger.setSeparabilityMetric(DEF_DIST);
			}
			
			// Re-read the field: it may have just been replaced by the fallback
			return (DistanceMetric) logger.dist_metric;
		}
	}
	
	/**
	 * Asks the algorithm recorded in {@link #alg} whether the given metric
	 * is acceptable.
	 *
	 * @param g the metric to check
	 * @return whatever the recorded algorithm's own check returns
	 */
	@Override
	public final boolean isValidMetric(GeometricallySeparable g) {
		final NeighborsAlgorithm algorithm = this.alg;
		return algorithm.isValidMetric(g);
	}
	
	
	/**
	 * Creates a neighbors model on behalf of another clusterer.
	 *
	 * @param caller the clusterer passed through to the superclass constructor
	 * @param planner supplies k, radius, leaf size and the tree algorithm
	 * @throws IllegalArgumentException if the planner's leaf size is less than 1
	 */
	protected BaseNeighborsModel(final AbstractClusterer caller, final BaseNeighborsPlanner planner) {
		super(caller, planner);
		this.init(planner);
	}
	
	/**
	 * Creates a neighbors model over the given data.
	 *
	 * @param data the data matrix, passed through to the superclass constructor
	 * @param planner supplies k, radius, leaf size and the tree algorithm
	 * @param as_is passed through unchanged to the superclass constructor
	 * @throws IllegalArgumentException if the planner's leaf size is less than 1
	 */
	protected BaseNeighborsModel(final RealMatrix data, final BaseNeighborsPlanner planner, final boolean as_is) {
		super(data, planner, as_is);
		this.init(planner);
	}
	
	/**
	 * Creates a neighbors model over the given data.
	 *
	 * @param data the data matrix, passed through to the superclass constructor
	 * @param planner supplies k, radius, leaf size and the tree algorithm
	 * @throws IllegalArgumentException if the planner's leaf size is less than 1
	 */
	protected BaseNeighborsModel(final RealMatrix data, final BaseNeighborsPlanner planner) {
		super(data, planner);
		this.init(planner);
	}
	
	/**
	 * Common tail of every constructor: copies the planner's settings onto
	 * the model, validates the leaf size, builds the search tree and keeps
	 * a reference to the tree's data.
	 *
	 * @param planner supplies k, radius, leaf size and the tree algorithm
	 * @throws IllegalArgumentException if the leaf size is less than 1
	 */
	private final void init(BaseNeighborsPlanner planner) {
		this.kNeighbors = planner.getK();
		this.radius = planner.getRadius();
		this.leafSize = planner.getLeafSize();

		// A non-null radius puts the model in radius mode
		this.radiusMode = (this.radius != null);

		// A disabled pre-check that reset non-DistanceMetric planners to
		// DEF_DIST used to sit here; the tree builder validates the metric now.

		if(this.leafSize < 1) {
			throw new IllegalArgumentException("leafsize must be positive");
		}

		// The builder internally handles metric validation
		final NeighborsAlgorithm requested = planner.getAlgorithm();
		this.tree = requested.buildTree(this.data, this.leafSize, this);

		// Keep the tree's own data reference and remember its length
		this.fit_X = this.tree.getData();
		this.m = this.fit_X.length;
	}

	/**
	 * Base parameter class for neighbors models. Holds the leaf size and
	 * tree algorithm (initialised to {@code DEF_LEAF_SIZE} and
	 * {@code DEF_ALGO}) and leaves k and radius to subclasses.
	 *
	 * <p>NOTE(review): the type parameter is restored here because
	 * {@code fitNewModel} returns {@code T}, which was otherwise undeclared.
	 * This assumes {@code NeighborsClassifierParameters} is itself declared
	 * with a model type parameter -- confirm against its declaration.
	 *
	 * @param <T> the concrete model type returned by {@link #fitNewModel(RealMatrix)}
	 */
	abstract public static class BaseNeighborsPlanner<T extends BaseNeighborsModel> 
			extends BaseClustererParameters 
			implements NeighborsClassifierParameters<T> {
		private static final long serialVersionUID = 8356804193088162871L;
		
		/** Leaf size reported by {@link #getLeafSize()}. */
		protected int leafSize = DEF_LEAF_SIZE;
		/** Tree algorithm reported by {@link #getAlgorithm()}. */
		protected NeighborsAlgorithm algo = DEF_ALGO;
		
		@Override abstract public T fitNewModel(RealMatrix d);

		/**
		 * Sets the tree algorithm.
		 * @param algo the algorithm to use
		 * @return a planner, presumably this one to allow chaining -- confirm in subclasses
		 */
		abstract public BaseNeighborsPlanner<T> setAlgorithm(NeighborsAlgorithm algo);

		/** @return the number of neighbors; the model treats this as nullable */
		abstract public Integer getK();

		/** @return the radius; a non-null value puts the model in radius mode */
		abstract public Double getRadius();
		
		/** @return the configured leaf size */
		final public int getLeafSize() { return leafSize; }

		/** @return the configured tree algorithm */
		final public NeighborsAlgorithm getAlgorithm() { return algo; }
	}
	
	/**
	 * Returns a copy of the neighborhood stored on this model.
	 *
	 * @return a copy of the stored result
	 * @throws ModelNotFitException if no result has been stored yet
	 */
	public Neighborhood getNeighbors() {
		if(this.res != null) {
			return this.res.copy();
		}

		throw new ModelNotFitException("model not yet fit");
	}
	
	/**
	 * A class to query the tree for neighborhoods in parallel. The input is
	 * split into chunks by the superclass; each task covers the chunk index
	 * range {@code [lo, hi)} and writes its rows into output arrays that
	 * are shared by the root task and all of its children.
	 *
	 * <p>NOTE(review): the supertype is bound to {@code Neighborhood} so
	 * that it agrees with the {@code compute()} and {@code reduce()} return
	 * types and keeps the inherited members typed. This assumes
	 * {@code ParallelChunkingTask} declares a result type parameter --
	 * confirm against its declaration.
	 *
	 * @author Taylor G Smith
	 */
	abstract static class ParallelNeighborhoodSearch extends ParallelChunkingTask<Neighborhood> {
		private static final long serialVersionUID = -1600812794470325448L;
		
		/** The model whose tree is queried. */
		final BaseNeighborsModel model;
		/** Shared output: one distance row per input row. */
		final double[][] distances;
		/** Shared output: one index row per input row. */
		final int[][] indices;
		/** First chunk index covered by this task (inclusive). */
		final int lo;
		/** Last chunk index covered by this task (exclusive). */
		final int hi;

		/**
		 * Creates the root task covering every chunk of {@code X}.
		 *
		 * @param X the rows to query
		 * @param model the model whose tree is queried
		 */
		public ParallelNeighborhoodSearch(double[][] X, BaseNeighborsModel model) {
			super(X); // this auto-chunks the data
			
			this.model = model;
			this.lo = 0;
			this.hi = strategy.getNumChunks(X);
			
			/*
			 * First get the length...
			 */
			int length = 0;
			for(Chunk c: this.chunks)
				length += c.size();
			
			// Rows are filled in later, chunk by chunk, in reduce()
			this.distances = new double[length][];
			this.indices = new int[length][];
		}
		
		/**
		 * Creates a child task over a sub-range of the parent's chunks. The
		 * child writes into the parent's output arrays.
		 *
		 * @param task the parent task
		 * @param lo first chunk index (inclusive)
		 * @param hi last chunk index (exclusive)
		 */
		public ParallelNeighborhoodSearch(ParallelNeighborhoodSearch task, int lo, int hi) {
			super(task);
			
			this.model = task.model;
			this.lo = lo;
			this.hi = hi;
			this.distances = task.distances;
			this.indices = task.indices;
		}

		/**
		 * Queries the model's tree for one chunk and stores the resulting
		 * rows in the shared arrays at the chunk's offset.
		 *
		 * @param chunk the chunk to query
		 * @return the neighborhood for this chunk only
		 */
		@Override
		public Neighborhood reduce(Chunk chunk) {
			Neighborhood n = query(model.tree, chunk.get());
			
			// assign to low index, since that's how we retrieved the chunk...
			final int start = chunk.start , end = start + chunk.size();
			double[][] d = n.getDistances();
			int[][] i = n.getIndices();
			
			// Set the distances and indices in place...
			for(int j = start, idx = 0; j < end; j++, idx++) {
				this.distances[j] = d[idx];
				this.indices[j] = i[idx];
			}
			
			return n;
		}

		/**
		 * Runs the search. A task covering at most one chunk queries it
		 * directly; otherwise the range is halved, the left half is forked,
		 * the right half is computed inline and the left half is joined.
		 *
		 * @return for a single-chunk task, that chunk's neighborhood;
		 *         otherwise a neighborhood wrapping the shared output arrays
		 */
		@Override
		protected Neighborhood compute() {
			if(hi - lo <= 1) { // generally should equal one...
				return reduce(chunks.get(lo));
			} else {
				int mid = this.lo + (this.hi - this.lo) / 2;
				ParallelNeighborhoodSearch left  = newInstance(this, this.lo, mid);
				ParallelNeighborhoodSearch right = newInstance(this, mid, this.hi);
				
				// Child results are ignored: children write into the shared arrays
				left.fork();
	            right.compute();
	            left.join();

	            return new Neighborhood(distances, indices);
			}
		}
		
		/**
		 * Creates a child task of the concrete subclass type.
		 *
		 * @param p the parent task
		 * @param lo first chunk index (inclusive)
		 * @param hi last chunk index (exclusive)
		 * @return the child task
		 */
		abstract ParallelNeighborhoodSearch newInstance(ParallelNeighborhoodSearch p, int lo, int hi);

		/**
		 * Performs the actual tree query for a set of rows.
		 *
		 * @param tree the tree to query
		 * @param X the rows to query with
		 * @return the neighborhood found for {@code X}
		 */
		abstract Neighborhood query(NearestNeighborHeapSearch tree, double[][] X);
	}
	
	
	/**
	 * Computes a neighborhood for the rows of the supplied matrix.
	 * NOTE(review): presumably queries this model's tree with the given
	 * rows -- confirm against the subclass implementations.
	 *
	 * @param matrix the matrix to find neighbors for
	 * @return the resulting neighborhood
	 */
	abstract Neighborhood getNeighbors(RealMatrix matrix);
	/**
	 * Fits the model. Return type is narrowed to {@code BaseNeighborsModel}.
	 * NOTE(review): presumably implementations store their result in
	 * {@code res}, which {@code getNeighbors()} requires to be non-null --
	 * confirm against the subclass implementations.
	 *
	 * @return a neighbors model, presumably this instance -- confirm
	 */
	@Override abstract protected BaseNeighborsModel fit();
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy