All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.SegmentedClassifier Maven / Gradle / Ivy

The newest version!
/**
* Global Sensor Networks (GSN) Source Code
* Copyright (c) 2006-2016, Ecole Polytechnique Federale de Lausanne (EPFL)
* 
* This file is part of GSN.
* 
* GSN is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
* 
* GSN is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
* 
* You should have received a copy of the GNU General Public License
* along with GSN.  If not, see .
* 
* File: src/weka/classifiers/SegmentedClassifier.java
*
* @author Sofiane Sarni
* @author Julien Eberle
*
*/

package weka.classifiers;

import java.util.Arrays;

import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.instance.ErrorBased;
import weka.filters.unsupervised.instance.SubsetByExpression;

/**
 * A classifier built using several sub-classifier, each one taking care of a segment.
 * The segments are defined by the list of cutting points.
 * It includes the option of filtering the data on each segment to downsample.
 * @author jeberle
 *
 */
public class SegmentedClassifier extends Classifier{

	private static final long serialVersionUID = 2311122072643482718L;
	
	private int segmentedClass = -1;
	
	private Double[] segments ;
	
	private Classifier[] classifiers;
	
	private Filter filter ;

	/**
	 * Initialize a new Classifier
	 * @param c the model of the sub-classifier
	 * @param segClass the feature to segment
	 * @param seg2 the list of cutting points
	 * @param filter the filter to downsample the data
	 * @throws Exception
	 */
	public SegmentedClassifier(Classifier c, int segClass,Double[] seg2,Filter filter) throws Exception {
		this.filter = filter;
		segmentedClass = segClass;
		segments = seg2;
		classifiers = Classifier.makeCopies(c, numSegments());
	}
	
	/**
	 * how many segments do we have?
	 * @return the number of segments
	 */
	private int numSegments(){
		return segments.length + 1;
	}
	
	/**
	 * Get the instances from the dataset, belonging to the given segment and apply the downsampling
	 * @param is the dataset
	 * @param idx the index of the segment to extract
	 * @return the instances in this segment or null if the segment doesn't exist
	 * @throws Exception
	 */
	public Instances getSegment(Instances is, int idx) throws Exception{
		if (idx >= numSegments() || idx < 0 || segmentedClass <= 0 || segmentedClass > is.numAttributes()){
			return null;
		}else{
			
			Filter f = Filter.makeCopy(filter);
			
			if (numSegments() == 1){
				//System.out.println("size before:"+is.numInstances());
				Instances ret = Filter.useFilter(is, f);
				//System.out.println("size after:"+ret.numInstances());
				return ret;}
			SubsetByExpression sbe = new SubsetByExpression();
			sbe.setInputFormat(is);
			String expr = "";
			if (idx == 0){
				expr += "(ATT"+segmentedClass+" < "+segments[idx]+")";
				sbe.setExpression(expr);
				Instances t = Filter.useFilter(is, Filter.makeCopy(sbe));
				if(f instanceof ErrorBased ){
					double[] e = Arrays.copyOfRange(((ErrorBased)f).getM_errors(),0,t.numInstances());
					((ErrorBased) f).setM_errors(e);
				}
			}else if(idx == numSegments()-1){
				expr += "(ATT"+segmentedClass+" >= "+segments[idx-1]+")";
				sbe.setExpression(expr);
				Instances t = Filter.useFilter(is, Filter.makeCopy(sbe));
				if(f instanceof ErrorBased ){
					double[] e = Arrays.copyOfRange(((ErrorBased)f).getM_errors(),((ErrorBased)f).getM_errors().length-t.numInstances(),((ErrorBased)f).getM_errors().length);
					((ErrorBased) f).setM_errors(e);
				}
			}else{
				String expr1 = "(ATT"+segmentedClass+" < "+segments[idx-1]+")";
				SubsetByExpression sbe1 = new SubsetByExpression();
				sbe1.setInputFormat(is);
				sbe1.setExpression(expr1);
				Instances t = Filter.useFilter(is, sbe1);
				expr += "(ATT"+segmentedClass+" >= "+segments[idx-1]+") and (ATT"+segmentedClass+" < "+segments[idx]+")";
				sbe.setExpression(expr);
				Instances tt = Filter.useFilter(is, Filter.makeCopy(sbe));
				if(f instanceof ErrorBased ){
					double[] e = Arrays.copyOfRange(((ErrorBased)f).getM_errors(),t.numInstances(),t.numInstances()+tt.numInstances());
					((ErrorBased) f).setM_errors(e);
				}
			}
			sbe.setExpression(expr);
			Instances t = Filter.useFilter(is, sbe);
			//System.out.println("size before:"+t.numInstances());
			Instances ret = Filter.useFilter(t, f);
			//System.out.println("size after:"+ret.numInstances());
			return ret;
		}
	}
	
	/**
	 * get the index of the segment corresponding to the given instance
	 * @param i the instance
	 * @return
	 */
	private int getSegmentNum(Instance i){
		double value = i.value(segmentedClass-1);
		int idx = 0;
		while (idx < segments.length && segments[idx] < value){
			idx++;
		}
		return idx;
	}
	

	@Override
	public void buildClassifier(Instances data) throws Exception {
		for (int i =0;i




© 2015 - 2024 Weber Informatics LLC | Privacy Policy