All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.chen0040.glm.evaluators.ConfusionMatrix Maven / Gradle / Ivy

package com.github.chen0040.glm.evaluators;

import com.github.chen0040.glm.utils.TupleTwo;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.*;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.stream.Collectors;


/**
 * Created by xschen on 11/16/16.
 */
public class ConfusionMatrix implements Serializable {
   private static final long serialVersionUID = 8446651320939507735L;
   private Map, Integer> matrix = new HashMap<>();
   private Set labels = new HashSet<>();

   private transient ReadWriteLock readWriteLock = new ReentrantReadWriteLock();

   public void incCount(String actual, String predicted) {
      readWriteLock.writeLock().lock();
      try{
         labels.add(actual);
         labels.add(predicted);
         TupleTwo key = new TupleTwo<>(actual, predicted);
         matrix.put(key, matrix.getOrDefault(key, 0) + 1);
      }finally {
         readWriteLock.writeLock().unlock();
      }
   }

   public List getLabels(){
      List result = new ArrayList<>();

      readWriteLock.readLock().lock();
      try {
         result.addAll(labels.stream().collect(Collectors.toList()));
      } finally {
         readWriteLock.readLock().unlock();
      }

      return result;
   }

   public void setLabels(List labels) {
      readWriteLock.writeLock().lock();
      try {
         this.labels.clear();
         this.labels.addAll(labels);
      }finally {
         readWriteLock.writeLock().unlock();
      }
   }

   // sum of a row representing class c, which is sum of cases that truely belong to class c
   public int getRowSum(String actual) {
      List list = this.getLabels();
      int sum = 0;
      for(int i=0; i < list.size(); ++i) {
         String predicted = list.get(i);
         sum += getCount(actual, predicted);
      }
      return sum;
   }


   // sum of a column representing class c, which is sum of cases the classifiers claims to belong to class c
   public int getColumnSum(String predicted) {
      List list = this.getLabels();
      int sum = 0;
      for(int i=0; i < list.size(); ++i) {
         String actual = list.get(i);
         sum += getCount(actual, predicted);
      }
      return sum;
   }



   public int getCount(String actual, String predicted) {
      int value = 0;
      readWriteLock.readLock().lock();
      try{
         value = matrix.getOrDefault(new TupleTwo<>(actual, predicted), 0);
      }finally {
         readWriteLock.readLock().unlock();
      }
      return value;
   }


   public void reset() {
      readWriteLock.writeLock().lock();
      try {
         matrix.clear();
      } finally {
         readWriteLock.writeLock().unlock();
      }
   }


   public Map, Integer> getMatrix() {
      return matrix;
   }


   public void setMatrix(Map, Integer> matrix) {
      this.matrix = matrix;
   }

   private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
      in.defaultReadObject();

      readWriteLock = new ReentrantReadWriteLock();
   }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy