All downloads are free. The search and download functionality uses the official Maven repository.

com.github.chen0040.glm.solvers.OneVsOneGlmClassifier Maven / Gradle / Ivy

package com.github.chen0040.glm.solvers;

import com.github.chen0040.glm.data.BasicDataFrame;
import com.github.chen0040.glm.data.DataFrame;
import com.github.chen0040.glm.data.DataRow;
import com.github.chen0040.glm.utils.TupleTwo;

import java.util.*;
import java.util.function.Supplier;
import java.util.stream.Collectors;


/**
 * Created by xschen on 8/20/2015 0020.
 */
public class OneVsOneGlmClassifier {
   protected List> classifiers;
   private double alpha = 0.1;
   private boolean shuffleData = false;
   private List classLabels = new ArrayList<>();
   private Supplier generator = () -> Glm.linear();

   private static String BINARY_LABEL = "success";

   public OneVsOneGlmClassifier(List classLabels){
      this.classLabels.addAll(classLabels);
      classifiers = new ArrayList<>();
   }

   public OneVsOneGlmClassifier(){
      super();
      classifiers = new ArrayList<>();
   }


   public OneVsOneGlmClassifier(Supplier binaryClassifierGenerator) {
      super();
      classifiers = new ArrayList<>();
      this.generator = binaryClassifierGenerator;
   }


   public boolean isShuffleData() {
      return shuffleData;
   }

   public void setShuffleData(boolean shuffleData) {
      this.shuffleData = shuffleData;
   }

   public double getAlpha() {
      return alpha;
   }

   public void setAlpha(double alpha) {
      this.alpha = alpha;
   }

   protected void createClassifiers(DataFrame dataFrame){
      classifiers = new ArrayList<>();

      if(classLabels.size()==0){
         classLabels.addAll(dataFrame.stream().map(DataRow::categoricalTarget).distinct().collect(Collectors.toList()));
      }
      for(int i=0; i < classLabels.size()-1; ++i){
         for(int j=i+1; j < classLabels.size(); ++j) {
            Glm svr1 = createClassifier(classLabels.get(i));
            Glm svr2 = createClassifier(classLabels.get(j));
            classifiers.add(new TupleTwo<>(svr1, svr2));
         }
      }
   }



   protected Glm createClassifier(String classLabel) {
      Glm svr = generator.get();
      svr.setName(classLabel);
      return svr;
   }

   protected double getClassifierScore(DataRow tuple, Glm classifier) {
      return classifier.transform(tuple);
   }

   protected List split(DataFrame dataFrame, int n){
      List miniFrames = new ArrayList<>();

      for(int i=0; i < n; ++i){
         miniFrames.add(new BasicDataFrame());
      }

      int index = 0;
      for(DataRow tuple : dataFrame) {
         int batchIndex = index % n;
         miniFrames.get(batchIndex).addRow(tuple);
         index++;
      }

      return miniFrames;
   }

   protected List remerge(List batches, int k){
      List newBatches = new ArrayList<>();


      for(int i=0; i < batches.size(); ++i){

         DataFrame newBatch = new BasicDataFrame();

         for(int j=0; j < k; ++j){
            int d = (i + j) % batches.size();
            DataFrame batch = batches.get(d);
            for(DataRow tuple : batch){
               newBatch.addRow(tuple.makeCopy());
            }
         }

         newBatches.add(newBatch);
      }
      return newBatches;
   }


   public double transform(DataRow row) {
      String label = classify(row);
      return classLabels.indexOf(label);
   }

   public void fit(DataFrame dataFrame) {

      createClassifiers(dataFrame);

      if(shuffleData) {
         dataFrame.shuffle();
      }

      List batches = split(dataFrame, classifiers.size());

      int k= Math.max(1, (int)alpha * batches.size());
      batches = remerge(batches, k);


      for(int i=0; i < classifiers.size(); ++i){
         TupleTwo pair = classifiers.get(i);
         Glm classifier1 = pair._1();
         Glm classifier2 = pair._2();

         classifier1.fit(createBinaryBatch(batches.get(i), classifier1.getName()));
         classifier2.fit(createBinaryBatch(batches.get(i), classifier2.getName()));
      }

   }

   private DataFrame createBinaryBatch(DataFrame dataFrame, String classLabel){
      DataFrame binaryBatch = new BasicDataFrame();
      for(DataRow row  : dataFrame){
         String label = row.categoricalTarget();
         DataRow rowWithBinaryTargetOutput = row.makeCopy();
         rowWithBinaryTargetOutput.setTargetCell(BINARY_LABEL, label.equals(classLabel) ? 1.0 : 0.0);
         binaryBatch.addRow(rowWithBinaryTargetOutput);
      }
      return binaryBatch;
   }


   public String classify(DataRow row) {

      row = row.makeCopy();
      if(row.getTargetColumnNames().isEmpty()) {
        row.setTargetColumnNames(Collections.singletonList(BINARY_LABEL));
      }

      Map scores = score(row);

      String predicatedClassLabel = null;
      int maxScore = 0;
      for(Map.Entry entry : scores.entrySet()){
         String label = entry.getKey();
         int score = entry.getValue();
         if(score > maxScore){
            maxScore= score;
            predicatedClassLabel = label;
         }
      }

      if(predicatedClassLabel == null) {
         predicatedClassLabel = "NA";
      }

      return predicatedClassLabel;
   }


   public void reset() {
      classifiers.clear();
      classLabels.clear();
   }


   public List getClassLabels() {
      return classLabels;
   }


   public Map score(DataRow row) {

      Map scores = new HashMap<>();

      for(int i=0; i < classifiers.size(); ++i){
         TupleTwo pair = classifiers.get(i);
         Glm classifier1 = pair._1();
         Glm classifier2 = pair._2();

         double score1 = getClassifierScore(row, classifier1);
         double score2 = getClassifierScore(row, classifier2);

         if(score1 == score2) continue;

         String winningLabel;
         if(score1 > score2) {
            winningLabel = classifier1.getName();
         }
         else {
            winningLabel = classifier2.getName();
         }
         if(scores.containsKey(winningLabel)){
            scores.put(winningLabel, scores.get(winningLabel) + 1);
         }else {
            scores.put(winningLabel, 1);
         }
      }

      return scores;
   }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy